mokshak commited on
Commit
12556e8
·
1 Parent(s): 03757d1

Add FastAPI endpoints for JSON API - /recommend, /health, /evaluate

Browse files
Files changed (2) hide show
  1. app.py +75 -1
  2. requirements.txt +2 -0
app.py CHANGED
@@ -465,9 +465,83 @@ with gr.Blocks(
465
  )
466
 
467
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
  # ============================================================================
469
  # LAUNCH APPLICATION
470
  # ============================================================================
471
 
472
  if __name__ == "__main__":
473
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
465
  )
466
 
467
 
468
+ # ============================================================================
469
+ # FASTAPI INTEGRATION FOR API ENDPOINT
470
+ # ============================================================================
471
+
472
+ from fastapi import FastAPI, Query
473
+ from fastapi.middleware.cors import CORSMiddleware
474
+ from typing import List, Dict, Any
475
+
476
+ app = FastAPI(
477
+ title="SHL Assessment Recommendation API",
478
+ description="AI-powered assessment recommendations using 2-Stage RAG Pipeline",
479
+ version="2.0.0"
480
+ )
481
+
482
+ app.add_middleware(
483
+ CORSMiddleware,
484
+ allow_origins=["*"],
485
+ allow_credentials=True,
486
+ allow_methods=["*"],
487
+ allow_headers=["*"],
488
+ )
489
+
490
+ @app.get("/health")
491
+ def health_check():
492
+ """Health check endpoint"""
493
+ return {
494
+ "status": "healthy",
495
+ "total_assessments": len(engine.assessments),
496
+ "model": "SBERT + Cross-Encoder"
497
+ }
498
+
499
+ @app.get("/recommend")
500
+ def recommend_api(
501
+ query: str = Query(..., min_length=3, description="Job description or requirements"),
502
+ max_results: int = Query(default=10, ge=1, le=10, description="Max results")
503
+ ):
504
+ """Get assessment recommendations as JSON"""
505
+ results = engine.recommend(query.strip(), max_results=max_results)
506
+ return {
507
+ "query": query,
508
+ "total_results": len(results),
509
+ "recommendations": results
510
+ }
511
+
512
+ @app.get("/evaluate")
513
+ def evaluate_api():
514
+ """Run evaluation metrics"""
515
+ test_cases = [
516
+ {"query": "need Python developer for backend engineering role", "relevant": ["Software Developer", "Coding", "IT Professional"]},
517
+ {"query": "hiring receptionist to handle customer inquiries", "relevant": ["Customer Service", "Contact Center", "Front Desk"]},
518
+ {"query": "executive position requiring strategic decisions", "relevant": ["Manager", "Supervisor", "Virtual Assessment"]},
519
+ {"query": "fresh graduates for analyst trainee program", "relevant": ["Graduate", "Apprentice", "Entry Level"]},
520
+ {"query": "B2B account executive in enterprise software", "relevant": ["Sales Professional", "Account Manager", "Sales Representative"]},
521
+ ]
522
+
523
+ results = []
524
+ for tc in test_cases:
525
+ recs = engine.recommend(tc["query"], max_results=5)
526
+ names = [r["name"] for r in recs]
527
+ # Simple P@1 check
528
+ p1 = 1.0 if any(rel.lower() in names[0].lower() for rel in tc["relevant"]) else 0.0
529
+ results.append({"query": tc["query"][:40], "top_result": names[0], "P@1": p1})
530
+
531
+ avg_p1 = sum(r["P@1"] for r in results) / len(results)
532
+ return {
533
+ "aggregate": {"mean_P@1": round(avg_p1, 3)},
534
+ "per_query": results,
535
+ "total_test_cases": len(test_cases)
536
+ }
537
+
538
+ # Mount Gradio app to FastAPI
539
+ app = gr.mount_gradio_app(app, demo, path="/")
540
+
541
  # ============================================================================
542
  # LAUNCH APPLICATION
543
  # ============================================================================
544
 
545
  if __name__ == "__main__":
546
+ import uvicorn
547
+ uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt CHANGED
@@ -3,3 +3,5 @@ sentence-transformers>=2.2.2
3
  faiss-cpu>=1.7.4
4
  numpy>=1.24.0
5
  torch>=2.0.0
 
 
 
3
  faiss-cpu>=1.7.4
4
  numpy>=1.24.0
5
  torch>=2.0.0
6
+ fastapi>=0.100.0
7
+ uvicorn>=0.23.0