Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI main application for XRD Analysis Tool. | |
| Serves both the API endpoints and the static React frontend. | |
| """ | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse, JSONResponse, PlainTextResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pathlib import Path | |
| from typing import Dict, List | |
| import re | |
| from .model_inference import XRDModelInference | |
# FastAPI application instance; title/description/version are the app
# metadata FastAPI publishes in its generated API docs.
app = FastAPI(
    title="OpenAlphaDiffract",
    version="1.0.0",
    description="Automated crystallographic analysis of powder X-ray diffraction data",
)

# CORS — allow all origins (same-origin on HF Spaces, open for embeds).
# NOTE(review): the endpoints visible in this module read no cookies or auth
# headers, so allow_credentials=True looks unnecessary; combined with a
# wildcard origin, Starlette echoes the caller's Origin for credentialed
# requests. Consider allow_credentials=False if nothing relies on it.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# One shared inference wrapper for the process; startup_event() asks it to
# load the model, and the request handlers call into it.
model_inference = XRDModelInference()
async def startup_event():
    """Startup hook: ask the shared ``model_inference`` object to load its model."""
    inference = model_inference
    inference.load_model()
async def health_check():
    """Liveness probe.

    Always reports ``status: "healthy"``; ``model_loaded`` is whatever
    ``model_inference.is_loaded()`` returns at the time of the call.
    """
    loaded = model_inference.is_loaded()
    return {"status": "healthy", "model_loaded": loaded}
async def predict(data: dict):
    """
    Predict XRD analysis from preprocessed data.

    Expects JSON payload: {"x": [2theta values], "y": [intensity values], "metadata": {...}}

    ``metadata`` is optional. Its ``timestamp``, ``filename`` and
    ``analysisCount`` entries (each defaulting to "unknown") are echoed back
    under ``request_metadata``, together with the processing time, when the
    model output is a dict. The timestamp is also returned as the
    ``X-Request-ID`` response header.

    Returns a JSONResponse:
      * 200 with the model output and no-cache headers;
      * 400 when metadata is not an object, x/y are missing or empty, are not
        arrays, or have different lengths;
      * 500 when inference raises (the exception text is put in ``error``).
    """
    import logging
    import time

    # perf_counter is monotonic, so the measured duration cannot go negative
    # or jump when the wall clock is adjusted mid-request.
    request_start = time.perf_counter()
    try:
        # Treat an explicit "metadata": null like a missing key instead of
        # crashing on None.get(...) and reporting it as a server error.
        metadata = data.get("metadata") or {}
        if not isinstance(metadata, dict):
            return JSONResponse(
                status_code=400,
                content={"error": "metadata must be a JSON object"},
            )
        request_id = metadata.get("timestamp", "unknown")
        filename = metadata.get("filename", "unknown")
        analysis_count = metadata.get("analysisCount", "unknown")

        x = data.get("x", [])
        y = data.get("y", [])
        if not x or not y:
            return JSONResponse(status_code=400, content={"error": "Missing x or y data"})
        # Scalars, strings or objects would otherwise blow up in len() or in
        # the model and surface as a 500 instead of a client error.
        if not isinstance(x, (list, tuple)) or not isinstance(y, (list, tuple)):
            return JSONResponse(
                status_code=400,
                content={"error": "x and y must be arrays"},
            )
        if len(x) != len(y):
            return JSONResponse(
                status_code=400,
                content={"error": "x and y arrays must have the same length"},
            )

        # NOTE(review): this call runs synchronously inside an async handler;
        # if inference is slow it blocks the event loop — confirm whether it
        # should be offloaded to a worker thread.
        results = model_inference.predict(x, y)
        request_time = (time.perf_counter() - request_start) * 1000

        if isinstance(results, dict):
            results["request_metadata"] = {
                "request_id": request_id,
                "filename": filename,
                "analysis_count": analysis_count,
                "processing_time_ms": request_time,
            }

        # The request id is client-supplied; keep only printable ASCII so it
        # can always be encoded as a valid HTTP header value.
        header_request_id = re.sub(r"[^\x20-\x7e]", "", str(request_id))
        return JSONResponse(
            content=results,
            headers={
                "Cache-Control": "no-cache, no-store, must-revalidate, private",
                "Pragma": "no-cache",
                "Expires": "0",
                "X-Request-ID": header_request_id,
            },
        )
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log
        # while still returning a JSON error to the client.
        logging.getLogger(__name__).exception("Prediction failed")
        return JSONResponse(
            status_code=500,
            content={"error": f"Prediction failed: {str(e)}"},
        )
| # --------------------------------------------------------------------------- | |
| # Example data endpoints | |
| # --------------------------------------------------------------------------- | |
EXAMPLE_DATA_DIR = Path(__file__).parent.parent / "example_data"

# Maps the numeric "Crystal System" header value to a display name.
CRYSTAL_SYSTEM_NAMES = {
    "1": "Triclinic",
    "2": "Monoclinic",
    "3": "Orthorhombic",
    "4": "Tetragonal",
    "5": "Trigonal",
    "6": "Hexagonal",
    "7": "Cubic",
}

# Header patterns, compiled once rather than looked up on every line.
_MATERIAL_ID_RE = re.compile(r"Material ID:\s*(\S+)")
_CRYSTAL_SYSTEM_RE = re.compile(r"Crystal System:\s*(\d+)")
_SPACE_GROUP_RE = re.compile(r"SPACE GROUP:\s*(\d+)")
_WAVELENGTH_RE = re.compile(r"wavelength:\s*([\d.]+)", re.IGNORECASE)

# Case-sensitive prefixes of header lines; "wavelength" is checked
# case-insensitively in _is_header_line.
_HEADER_PREFIXES = ("#", "CELL", "SPACE")


def _is_header_line(line: str) -> bool:
    """Return True for blank lines and metadata lines that precede the data rows."""
    return (
        not line
        or line.startswith(_HEADER_PREFIXES)
        or line.lower().startswith("wavelength")
    )


def _parse_example_metadata(filepath: Path) -> dict:
    """Extract metadata from the header lines of a .dif file.

    Lines are scanned from the top until the first non-blank line that is not
    a header line (see ``_is_header_line``); the remainder of the file is not
    read. Fields whose pattern never matches stay ``None``, and all extracted
    values are kept as strings.

    Args:
        filepath: Path of the ``.dif`` file to inspect.

    Returns:
        Dict with keys ``filename``, ``material_id``, ``crystal_system``,
        ``crystal_system_name``, ``space_group`` and ``wavelength``.
    """
    meta = {
        "filename": filepath.name,
        "material_id": None,
        "crystal_system": None,
        "crystal_system_name": None,
        "space_group": None,
        "wavelength": None,
    }
    # Explicit UTF-8 instead of the platform default; errors="replace" keeps
    # an undecodable byte in a comment from raising UnicodeDecodeError.
    with open(filepath, "r", encoding="utf-8", errors="replace") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not _is_header_line(line):
                break
            if m := _MATERIAL_ID_RE.search(line):
                meta["material_id"] = m.group(1)
            if m := _CRYSTAL_SYSTEM_RE.search(line):
                num = m.group(1)
                meta["crystal_system"] = num
                meta["crystal_system_name"] = CRYSTAL_SYSTEM_NAMES.get(
                    num, f"Unknown ({num})"
                )
            if m := _SPACE_GROUP_RE.search(line):
                meta["space_group"] = m.group(1)
            if m := _WAVELENGTH_RE.search(line):
                meta["wavelength"] = m.group(1)
    return meta
async def list_examples():
    """Describe every ``*.dif`` example file.

    Returns one ``_parse_example_metadata`` dict per file, in sorted path
    order, or an empty list when the example directory does not exist.
    """
    if EXAMPLE_DATA_DIR.exists():
        return [
            _parse_example_metadata(dif_path)
            for dif_path in sorted(EXAMPLE_DATA_DIR.glob("*.dif"))
        ]
    return []
async def get_example(filename: str):
    """Return the raw text content of an example data file.

    Args:
        filename: Bare file name of a file directly inside ``EXAMPLE_DATA_DIR``
            (client-supplied).

    Returns:
        PlainTextResponse with the file contents, decoded as UTF-8 with
        undecodable bytes replaced.

    Raises:
        HTTPException(400): the name contains a path separator or "..", or the
            file's real location (after resolving symlinks) is not directly
            inside the example directory.
        HTTPException(404): no regular file with that name exists.
    """
    if "/" in filename or "\\" in filename or ".." in filename:
        raise HTTPException(status_code=400, detail="Invalid filename")
    filepath = EXAMPLE_DATA_DIR / filename
    # is_file() is False for missing paths too, so this also covers exists().
    if not filepath.is_file():
        raise HTTPException(status_code=404, detail="Example file not found")
    # Defence in depth: the string checks above cannot see symlinks or
    # platform-specific forms such as drive-relative names, so confirm the
    # real location is still directly inside the example directory.
    if filepath.resolve().parent != EXAMPLE_DATA_DIR.resolve():
        raise HTTPException(status_code=400, detail="Invalid filename")
    return PlainTextResponse(filepath.read_text(encoding="utf-8", errors="replace"))
| # --------------------------------------------------------------------------- | |
| # Static files and SPA support | |
| # --------------------------------------------------------------------------- | |
frontend_dist = Path(__file__).parent.parent / "frontend" / "dist"
if frontend_dist.exists():
    app.mount(
        "/assets",
        StaticFiles(directory=str(frontend_dist / "assets")),
        name="assets",
    )

    async def serve_spa(path: str):
        """Serve React SPA.

        ``path`` comes from the request URL, so a file is served only when it
        resolves to a regular file inside ``frontend_dist``. Anything else
        gets ``index.html``, including client-side routes, "../" traversal
        attempts and absolute paths.
        """
        root_dir = frontend_dist.resolve()
        try:
            file_path = (frontend_dist / path).resolve()
        except (OSError, ValueError):
            # e.g. an embedded NUL byte; treat it like an unknown route.
            file_path = None
        if (
            file_path is not None
            and root_dir in file_path.parents
            and file_path.is_file()
        ):
            return FileResponse(file_path)
        return FileResponse(frontend_dist / "index.html")
else:
    async def root():
        """Placeholder response used when the frontend build directory is missing."""
        return {"message": "Frontend not built. Run 'npm run build' in frontend/"}