import csv
import os
from datetime import datetime
from typing import List

import clip
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from PIL import Image
|
|
| |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
| |
| model, preprocess = clip.load("ViT-B/32", device=device) |
| model.eval() |
|
|
| |
| st.set_page_config(page_title="Few-Shot Fault Detection", layout="wide") |
| st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)") |
|
|
| st.markdown(""" |
| This demo uses the **smaller `ViT-B/32` encoder** from OpenAI's CLIP model to classify test images as **Nominal** or **Defective**, based on few-shot learning using user-provided reference images. |
| |
| ⚠️ **Note**: This app is running on a **free CPU tier** and is meant for demonstration purposes. For more advanced use cases, including GPU acceleration, custom training, and larger models, please refer to: |
| |
| 📄 [Megahed et al. (2025)](https://arxiv.org/abs/2501.12596): |
| *Adapting OpenAI's CLIP Model for Few-Shot Image Inspection in Manufacturing Quality Control: An Expository Case Study with Multiple Application Examples* |
| |
| 🔗 [GitHub & Colab links available in the paper](https://arxiv.org/abs/2501.12596) |
| """) |
|
|
| |
| def few_shot_fault_classification( |
| test_images: List[Image.Image], |
| test_image_filenames: List[str], |
| nominal_images: List[Image.Image], |
| nominal_descriptions: List[str], |
| defective_images: List[Image.Image], |
| defective_descriptions: List[str], |
| num_few_shot_nominal_imgs: int, |
| file_path: str = '.', |
| file_name: str = 'image_classification_results.csv', |
| print_one_liner: bool = False |
| ): |
| if not isinstance(test_images, list): test_images = [test_images] |
| if not isinstance(test_image_filenames, list): test_image_filenames = [test_image_filenames] |
| if not isinstance(nominal_images, list): nominal_images = [nominal_images] |
| if not isinstance(nominal_descriptions, list): nominal_descriptions = [nominal_descriptions] |
| if not isinstance(defective_images, list): defective_images = [defective_images] |
| if not isinstance(defective_descriptions, list): defective_descriptions = [defective_descriptions] |
|
|
| csv_file = os.path.join(file_path, file_name) |
| results = [] |
|
|
| with torch.no_grad(): |
| nominal_features = torch.stack([model.encode_image(img.unsqueeze(0)).squeeze(0).to(device) for img in nominal_images]) |
| nominal_features /= nominal_features.norm(dim=-1, keepdim=True) |
|
|
| defective_features = torch.stack([model.encode_image(img.unsqueeze(0)).squeeze(0).to(device) for img in defective_images]) |
| defective_features /= defective_features.norm(dim=-1, keepdim=True) |
|
|
| csv_data = [] |
|
|
| for idx, test_img in enumerate(test_images): |
| test_features = model.encode_image(test_img.unsqueeze(0)).squeeze(0).to(device) |
| test_features /= test_features.norm(dim=-1, keepdim=True) |
|
|
| max_nom_sim, max_def_sim = -float('inf'), -float('inf') |
| max_nom_idx, max_def_idx = -1, -1 |
|
|
| for i in range(nominal_features.shape[0]): |
| sim = (test_features @ nominal_features[i].T).item() |
| if sim > max_nom_sim: |
| max_nom_sim, max_nom_idx = sim, i |
|
|
| for j in range(defective_features.shape[0]): |
| sim = (test_features @ defective_features[j].T).item() |
| if sim > max_def_sim: |
| max_def_sim, max_def_idx = sim, j |
|
|
| similarities = torch.tensor([max_nom_sim, max_def_sim]) |
| probabilities = F.softmax(similarities, dim=0).tolist() |
| prob_nom, prob_def = probabilities |
|
|
| classification = "Defective" if prob_def > prob_nom else "Nominal" |
|
|
| csv_data.append({ |
| "datetime_of_operation": datetime.now().isoformat(), |
| "num_few_shot_nominal_imgs": num_few_shot_nominal_imgs, |
| "image_path": test_image_filenames[idx], |
| "image_name": test_image_filenames[idx].split('/')[-1], |
| "classification_result": classification, |
| "non_defect_prob": round(prob_nom, 3), |
| "defect_prob": round(prob_def, 3), |
| "nominal_description": nominal_descriptions[max_nom_idx], |
| "defective_description": defective_descriptions[max_def_idx] if defective_images else "N/A" |
| }) |
|
|
| if print_one_liner: |
| print(f"{test_image_filenames[idx]} classified as {classification} " |
| f"(Nominal Prob: {prob_nom:.3f}, Defective Prob: {prob_def:.3f})") |
|
|
| file_exists = os.path.isfile(csv_file) |
| with open(csv_file, mode='a' if file_exists else 'w', newline='') as file: |
| import csv |
| fieldnames = [ |
| "datetime_of_operation", "num_few_shot_nominal_imgs", "image_path", "image_name", |
| "classification_result", "non_defect_prob", "defect_prob", |
| "nominal_description", "defective_description" |
| ] |
| writer = csv.DictWriter(file, fieldnames=fieldnames) |
| if not file_exists: |
| writer.writeheader() |
| for row in csv_data: |
| writer.writerow(row) |
|
|
| return "" |
|
|
| |
| if 'nominal_images' not in st.session_state: |
| st.session_state.nominal_images = [] |
| if 'defective_images' not in st.session_state: |
| st.session_state.defective_images = [] |
| if 'test_images' not in st.session_state: |
| st.session_state.test_images = [] |
| if 'results' not in st.session_state: |
| st.session_state.results = [] |
|
|
| |
| tab1, tab2, tab3 = st.tabs(["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"]) |
|
|
| |
| with tab1: |
| st.header("Upload Reference Images") |
| nominal_files = st.file_uploader("Upload Nominal Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg']) |
| defective_files = st.file_uploader("Upload Defective Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg']) |
|
|
| if nominal_files: |
| st.session_state.nominal_images = [preprocess(Image.open(file).convert("RGB")).to(device) for file in nominal_files] |
| st.session_state.nominal_descriptions = [file.name for file in nominal_files] |
| st.success(f"Uploaded {len(nominal_files)} nominal images.") |
|
|
| if defective_files: |
| st.session_state.defective_images = [preprocess(Image.open(file).convert("RGB")).to(device) for file in defective_files] |
| st.session_state.defective_descriptions = [file.name for file in defective_files] |
| st.success(f"Uploaded {len(defective_files)} defective images.") |
|
|
| |
| with tab2: |
| st.header("Upload Test Image(s)") |
| test_files = st.file_uploader("Upload Test Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg']) |
|
|
| if st.button("🔍 Run Classification") and test_files: |
| test_images = [preprocess(Image.open(file).convert("RGB")).to(device) for file in test_files] |
| test_filenames = [file.name for file in test_files] |
|
|
| few_shot_fault_classification( |
| test_images=test_images, |
| test_image_filenames=test_filenames, |
| nominal_images=st.session_state.nominal_images, |
| nominal_descriptions=st.session_state.nominal_descriptions, |
| defective_images=st.session_state.defective_images, |
| defective_descriptions=st.session_state.defective_descriptions, |
| num_few_shot_nominal_imgs=len(st.session_state.nominal_images), |
| file_path=".", |
| file_name="streamlit_results.csv", |
| print_one_liner=False |
| ) |
|
|
| st.success("Classification complete!") |
| st.session_state.results = "streamlit_results.csv" |
|
|
| |
| with tab3: |
| st.header("Classification Results") |
| if os.path.exists("streamlit_results.csv"): |
| df = pd.read_csv("streamlit_results.csv") |
| st.dataframe(df) |
| st.download_button("📥 Download Results", data=df.to_csv(index=False), file_name="classification_results.csv", mime="text/csv") |
| else: |
| st.info("No results yet. Please classify some test images.") |
|
|