"""Wafer Defect API.

FastAPI service behind the wafer-defect dashboard. At import time it:

* loads the ``wafer_logs`` table from the middleware SQLite database into the
  module-level DataFrame ``df``;
* opens (or creates) the persistent ChromaDB collection
  ``semiconductor_knowledge`` used for retrieval-augmented chat;
* creates a Gemini client when ``GEMINI_API_KEY`` is set;
* unpickles the material-prediction model package, if present.

Endpoints serve KPIs, chart data, model status, material-need predictions and
a chat assistant. Missing resources degrade to empty data or an
``{"error": ...}`` payload rather than failing at startup.
"""
import sys
import os
import pickle
import sqlite3
from contextlib import closing
from typing import List, Dict

import pandas as pd
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from dotenv import load_dotenv
from google import genai
import chromadb

env_path = os.path.join(os.path.dirname(__file__), '.env')
load_dotenv(env_path)

# Add parent dir to path so we can import from middleware
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from middleware.material_predictor import predict_material_needs  # noqa: E402

app = FastAPI(title="Wafer Defect API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Robust paths for Docker/Hosting
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DB_PATH = os.path.join(BASE_DIR, '..', 'middleware', 'wafer_control.db')
MODEL_PATH = os.path.join(BASE_DIR, '..', 'middleware', 'material_model.pkl')
CHROMA_PATH = os.path.join(BASE_DIR, 'chroma_db')

# Ensure directories exist
os.makedirs(CHROMA_PATH, exist_ok=True)

DEFECT_COLORS = {
    'Center': '#ef4444',
    'Donut': '#f59e0b',
    'Edge-Loc': '#10b981',
    'Edge-Ring': '#3b82f6',
    'Loc': '#8b5cf6',
    'Random': '#ec4899',
    'Scratch': '#06b6d4',
    'Near-full': '#f97316',
    'None': '#6b7280',
    'Undetected': '#374151',
}

# Columns of ``wafer_logs`` that the endpoints in this module read. The
# fallback DataFrame used when the DB is missing is given this schema so the
# endpoints return empty results instead of raising KeyError on column access.
WAFER_LOG_COLUMNS = [
    'id', 'scan_time', 'scan_date', 'status', 'action', 'defect_type',
    'ground_truth', 'confidence', 'material_wasted_pct',
]

# Globally load data so we don't block requests
if os.path.exists(DB_PATH):
    print(f"Loading DB from {DB_PATH}...")
    # sqlite3's own context manager only commits/rolls back; closing() is what
    # actually releases the connection, even if the query raises.
    with closing(sqlite3.connect(DB_PATH)) as conn:
        df = pd.read_sql_query("SELECT * FROM wafer_logs", conn)
    df['scan_time'] = pd.to_datetime(df['scan_time'])
    df['scan_date'] = df['scan_time'].dt.date
else:
    print(f"Warning: DB not found at {DB_PATH}. Dashboard will be empty.")
    df = pd.DataFrame(columns=WAFER_LOG_COLUMNS)

# Setup Vector DB and LLM
print(f"Connecting to ChromaDB at {CHROMA_PATH}...")
try:
    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
    collection = chroma_client.get_or_create_collection(name="semiconductor_knowledge")
except Exception as e:
    # Chat still works without retrieval, so degrade instead of failing startup.
    print(f"Warning: Could not connect to ChromaDB collection. Error: {e}")
    collection = None

print("Initializing Gemini API...")
gemini_client = None
_gemini_api_key = os.getenv("GEMINI_API_KEY")
if _gemini_api_key:
    gemini_client = genai.Client(api_key=_gemini_api_key)
else:
    print("Warning: GEMINI_API_KEY not found in environment.")

print("Loading ML model...")
model_pkg = None
if os.path.exists(MODEL_PATH):
    # NOTE(security): pickle.load executes arbitrary code from the file. This
    # is only acceptable because MODEL_PATH is a locally produced artifact;
    # never point it at an untrusted or user-supplied file.
    with open(MODEL_PATH, 'rb') as f:
        model_pkg = pickle.load(f)


def _failed_scans():
    """Return the rows of ``df`` whose ``status`` is ``'FAIL'``."""
    return df[df['status'] == 'FAIL']


def _percent(part, whole):
    """Return ``part / whole`` as a percentage rounded to one decimal, or 0
    when ``whole`` is 0 (avoids ZeroDivisionError on an empty dataset)."""
    return round((part / whole) * 100, 1) if whole else 0


@app.get("/api/kpi")
def get_kpis():
    """Return headline KPIs for the whole dataset.

    The dict holds scan/pass/fail counts, pass and fail rates (percent, one
    decimal), the number of scans whose action is ``ROUTE_TO_SCRAP``, and the
    mean ``material_wasted_pct`` and ``confidence`` over failed scans (0 when
    there are no failures).
    """
    total_scans = len(df)
    fail_df = _failed_scans()
    fail_count = len(fail_df)
    pass_count = len(df[df['status'] == 'PASS'])
    pass_rate = _percent(pass_count, total_scans)
    # Derive the fail rate from the FAIL count rather than ``100 - pass_rate``:
    # the shortcut reports 100% on an empty dataset and drifts if any row
    # carries a status other than PASS or FAIL.
    fail_rate = _percent(fail_count, total_scans)
    scrap_count = len(df[df['action'] == 'ROUTE_TO_SCRAP'])
    avg_waste = round(fail_df['material_wasted_pct'].mean(), 2) if fail_count else 0
    avg_confidence = round(fail_df['confidence'].mean(), 2) if fail_count else 0
    return {
        "total_scans": total_scans,
        "pass_count": pass_count,
        "pass_rate": pass_rate,
        "fail_count": fail_count,
        "fail_rate": fail_rate,
        "scrap_count": scrap_count,
        "avg_waste": avg_waste,
        "avg_confidence": avg_confidence,
    }


@app.get("/api/charts/defects")
def get_defects():
    """Return defect-type frequencies over failed scans.

    ``predictions`` counts the ``defect_type`` column (all types).
    ``ground_truth`` counts the ``ground_truth`` column (top 15 only).
    """
    fail_df = _failed_scans()
    defect_counts = fail_df['defect_type'].value_counts().reset_index()
    defect_counts.columns = ['defect_type', 'count']
    gt_counts = fail_df['ground_truth'].value_counts().reset_index()
    gt_counts.columns = ['ground_truth', 'count']
    return {
        "predictions": defect_counts.to_dict(orient="records"),
        "ground_truth": gt_counts.head(15).to_dict(orient="records"),
    }


@app.get("/api/charts/waste")
def get_waste():
    """Return total waste per defect type (failed scans) and action counts.

    ``total_waste`` is the sum of ``material_wasted_pct`` divided by 100, so
    it is expressed in wafer-equivalents if the column is a per-wafer
    percentage. Rows are sorted ascending by ``total_waste``.
    """
    fail_df = _failed_scans()
    waste_by_type = (
        fail_df.groupby('defect_type')
        .agg(total_waste=('material_wasted_pct', lambda x: x.sum() / 100.0))
        .reset_index()
        .sort_values('total_waste', ascending=True)
    )
    action_counts = df['action'].value_counts().reset_index()
    action_counts.columns = ['action', 'count']
    return {
        "waste_by_type": waste_by_type.to_dict(orient="records"),
        "actions": action_counts.to_dict(orient="records"),
    }


@app.get("/api/charts/trends")
def get_trends():
    """Return per-day fail rate (percent) and total waste, keyed by scan date.

    The three lists in the response are index-aligned. They are empty when
    there is no data.
    """
    if df.empty:
        # Nothing to aggregate; skip the groupby on an untyped empty frame.
        return {"dates": [], "fail_rate": [], "waste": []}
    daily = df.groupby('scan_date').agg(
        scans=('id', 'count'),
        fails=('status', lambda x: (x == 'FAIL').sum()),
        waste=('material_wasted_pct', lambda x: x.sum() / 100.0),
    ).reset_index()
    daily['fail_rate'] = round((daily['fails'] / daily['scans']) * 100, 1)
    return {
        "dates": daily['scan_date'].astype(str).tolist(),
        "fail_rate": daily['fail_rate'].tolist(),
        "waste": daily['waste'].tolist(),
    }


@app.get("/api/model/status")
def model_status():
    """Report whether the material model is loaded, plus its metrics.

    When loaded, the response includes R^2 (4 dp), MAE (2 dp) and the 10 most
    important features, sorted ascending by importance.
    """
    if not model_pkg:
        return {"loaded": False}
    m = model_pkg['metrics']
    imp = m['importances']
    imp_df = pd.DataFrame({'feature': list(imp.keys()), 'importance': list(imp.values())})
    imp_df = imp_df.sort_values('importance', ascending=True).tail(10)
    return {
        "loaded": True,
        "metrics": {"r2": round(m['r2'], 4), "mae": round(m['mae'], 2)},
        "importance": imp_df.to_dict(orient="records"),
    }


class PredictionRequest(BaseModel):
    """Input for ``/api/predict``: planned scan count and expected fail rate in
    percent (e.g. ``12.5`` for 12.5%)."""

    scans: int
    fail_rate: float


@app.post("/api/predict")
def predict_waste(req: PredictionRequest):
    """Predict material needs for ``req.scans`` wafers at ``req.fail_rate``%.

    The defect-type mix is taken from the observed failed scans. The request's
    fail rate is echoed back in the result under ``fail_rate``.
    """
    if not model_pkg:
        return {"error": "No model loaded"}
    dist = _failed_scans()['defect_type'].value_counts(normalize=True).to_dict()
    # predict_material_needs takes the fail rate as a fraction, not a percent.
    pred = predict_material_needs(
        model_pkg['model'], model_pkg['feature_cols'], req.scans, req.fail_rate / 100.0, dist
    )
    pred['fail_rate'] = req.fail_rate
    return pred


class ChatMessage(BaseModel):
    """One chat turn. ``role == "user"`` marks user turns; any other role is
    sent to Gemini as a model turn."""

    role: str
    content: str


class ChatRequest(BaseModel):
    """Full conversation history; the last message is used as the RAG query."""

    messages: List[ChatMessage]


def _retrieve_context(query: str) -> str:
    """Return up to 2 ChromaDB documents matching ``query``, joined by
    newlines. Returns "" when there is no collection, no query, no hits, or
    the lookup raises (the error is printed)."""
    if not (collection and query):
        return ""
    try:
        results = collection.query(query_texts=[query], n_results=2)
        if results and results['documents'] and results['documents'][0]:
            return "\n".join(results['documents'][0])
    except Exception as e:
        # Retrieval is best-effort; the chat proceeds without context.
        print(f"ChromaDB Query Error: {e}")
    return ""


def _live_dashboard_summary() -> str:
    """Render the current KPIs as a text block for the system prompt.

    The pass rate is computed exactly as ``/api/kpi`` does, so both agree.
    """
    total_scans = len(df)
    fail_df = _failed_scans()
    fail_count = len(fail_df)
    pass_rate = _percent(len(df[df['status'] == 'PASS']), total_scans)
    top_defects = fail_df['defect_type'].value_counts().head(3).to_dict()
    return f"""
    Current Dashboard State:
    - Total Wafers Scanned: {total_scans}
    - Current Pass Rate: {pass_rate}%
    - Total Defective Wafers: {fail_count}
    - Top Defect Types Right Now: {top_defects}
    """


@app.post("/api/chat")
def chat_with_bot(req: ChatRequest):
    """Answer the conversation with Gemini, grounded in live KPIs and in
    ChromaDB documents retrieved for the latest message.

    Returns ``{"response": text}`` on success. Returns ``{"error": message}``
    when Gemini is not configured, the request has no messages, or the
    Gemini call raises.
    """
    if not gemini_client:
        return {"error": "Gemini API key not configured"}
    if not req.messages:
        # Without this guard Gemini would be called with empty contents.
        return {"error": "No messages provided"}

    user_message = req.messages[-1].content

    # 1. RAG Retrieval from ChromaDB
    context_docs = _retrieve_context(user_message)

    # 2. Get Live Dashboard Context
    live_kpis = _live_dashboard_summary()

    # 3. Construct System Prompt
    system_instruction = f"""
    You are the 'Gorilla Semiconductors Engineering Assistant', an expert semiconductor manufacturing assistant.
    You help engineers understand dashboard data and troubleshoot wafer defects.
    Maintain a strictly professional, analytical, and authoritative engineering tone.

    Here is the LIVE DATA from the dashboard:
    {live_kpis}

    Here is retrieved technical context from our engineering database based on the user's query:
    {context_docs if context_docs else "No specific engineering docs retrieved."}

    Use the live data to answer questions about 'current status' or 'dashboard'.
    Use the engineering docs to answer questions about 'why' a defect happens.
    """

    try:
        # Convert messages to format expected by google-genai
        contents = [
            genai.types.Content(
                role="user" if msg.role == "user" else "model",
                parts=[genai.types.Part.from_text(text=msg.content)],
            )
            for msg in req.messages
        ]
        response = gemini_client.models.generate_content(
            model='gemini-2.5-flash-lite',
            contents=contents,
            config=genai.types.GenerateContentConfig(
                system_instruction=system_instruction,
                temperature=0.3,
            ),
        )
        return {"response": response.text}
    except Exception as e:
        print(f"Gemini API Error: {e}")
        return {"error": str(e)}