Amogh1221 commited on
Commit
8fdd265
·
verified ·
1 Parent(s): 2128c6d

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +20 -0
  2. main.py +201 -0
  3. requirements.txt +6 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

WORKDIR /app

# Install Python dependencies first so this layer is cached across code-only changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Bake the NLTK corpora into the image so container startup does not hit the
# network. /usr/local/share/nltk_data is on NLTK's default search path, and
# main.py still downloads lazily as a fallback if anything is missing.
RUN python -m nltk.downloader -d /usr/local/share/nltk_data stopwords punkt_tab wordnet

# Copy application code and model artifacts.
COPY . .

# Hugging Face Spaces expects the app to listen on port 7860.
EXPOSE 7860

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ backend/main.py
3
+
4
+ Run:
5
+ pip install -r requirements.txt
6
+ uvicorn main:app --reload --port 8000
7
+ """
8
+
9
+ import os
10
+ import re
11
+ import pickle
12
+ import time
13
+ from contextlib import asynccontextmanager
14
+ from typing import Optional
15
+
16
+ import nltk
17
+ from fastapi import FastAPI, HTTPException
18
+ from fastapi.middleware.cors import CORSMiddleware
19
+ from pydantic import BaseModel, Field, field_validator
20
+
21
# ── NLTK setup ────────────────────────────────────────────────────────────────
# Ensure required NLTK corpora exist; download lazily on first run if missing.
for _pkg, _path in [
    ("stopwords", "corpora/stopwords"),
    ("punkt_tab", "tokenizers/punkt_tab"),
    ("wordnet", "corpora/wordnet"),
]:
    try:
        nltk.data.find(_path)
    except LookupError:
        nltk.download(_pkg, quiet=True)

from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

# frozenset gives O(1) membership tests in cleaning(); the original list was
# O(n) per token. Also use the imported `stopwords` name directly.
_STOP_WORDS = frozenset(stopwords.words("english"))
_LEMMATIZER = WordNetLemmatizer()
37
+
38
+
39
# ── cleaning() — behaviorally identical to the notebook's preprocessing ──────
def cleaning(text: str) -> str:
    """Lowercase, strip non-letters, tokenize, drop stopwords, lemmatize.

    Returns the space-joined lemmatized tokens.
    """
    lowered = str(text).lower()
    letters_only = re.sub(r"[^a-zA-Z\s]", "", lowered)
    tokens = nltk.word_tokenize(letters_only)
    kept = (tok for tok in tokens if tok not in _STOP_WORDS)
    return " ".join(_LEMMATIZER.lemmatize(tok) for tok in kept)
47
+
48
+
49
+ # ── Artifact loading ──────────────────────────────────────────────────────────
50
+ ARTIFACT_DIR = os.getenv("ARTIFACT_DIR", "./artifacts")
51
+
52
+ MODEL = None
53
+ VECTORIZER = None
54
+ ENCODER = None
55
+
56
+
57
+ def _load(fname: str):
58
+ path = os.path.join(ARTIFACT_DIR, fname)
59
+ if not os.path.exists(path):
60
+ raise FileNotFoundError(
61
+ f"Artifact not found: {path}\n"
62
+ f"Unzip model.zip into {ARTIFACT_DIR}/ first."
63
+ )
64
+ with open(path, "rb") as f:
65
+ return pickle.load(f)
66
+
67
+
68
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the pickled model, vectorizer and label encoder at startup."""
    global MODEL, VECTORIZER, ENCODER
    print(f"Loading artifacts from: {ARTIFACT_DIR}")
    MODEL, VECTORIZER, ENCODER = (
        _load(name) for name in ("model.pkl", "tfidf.pkl", "encoder.pkl")
    )
    print(f"Model loaded ✓ | {type(MODEL).__name__} | Classes: {list(ENCODER.classes_)}")
    yield
    print("Shutting down.")
78
+
79
+
80
# ── App ───────────────────────────────────────────────────────────────────────
app = FastAPI(
    title="Mental Health Sentiment Analysis API",
    version="1.0.0",
    lifespan=lifespan,  # loads model artifacts on startup
)

# NOTE(review): allow_origins=["*"] together with allow_credentials=True is a
# combination browsers reject per the CORS spec; lock origins down (or drop
# credentials) before production — confirm whether any client needs cookies.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
94
+
95
+
96
# ── Schemas ───────────────────────────────────────────────────────────────────
class PredictRequest(BaseModel):
    """Request body for /predict: one free-text snippet to classify."""

    # 3–5000 characters; constraints are validated before strip_text runs.
    text: str = Field(..., min_length=3, max_length=5000)

    @field_validator("text")
    @classmethod
    def strip_text(cls, v: str) -> str:
        # NOTE(review): runs after min_length, so "   " (3 spaces) passes here
        # and is only rejected later by the empty-after-cleaning check — confirm
        # this ordering is intended.
        return v.strip()
104
+
105
+
106
class ClassProbability(BaseModel):
    """One (label, probability) pair in a prediction's class distribution."""

    label: str
    probability: float  # rounded to 4 decimal places by _infer
109
+
110
+
111
class PredictResponse(BaseModel):
    """Prediction result for a single text."""

    label: str                             # predicted class name
    confidence: float                      # probability of `label`, 4 dp
    probabilities: list[ClassProbability]  # full distribution, sorted descending
    cleaned_input: str                     # text after the cleaning() pipeline
    latency_ms: float                      # wall-clock inference time, 2 dp
117
+
118
+
119
class BatchPredictRequest(BaseModel):
    """Request body for /predict/batch: 1–50 texts."""

    # min/max_length here bound the *list* size, not each string; the
    # per-string limits of PredictRequest are not applied to batch items.
    texts: list[str] = Field(..., min_length=1, max_length=50)
121
+
122
+
123
class BatchPredictResponse(BaseModel):
    """Per-text results plus total wall-clock latency for the whole batch."""

    results: list[PredictResponse]
    total_latency_ms: float  # includes all per-item inference time, 2 dp
126
+
127
+
128
class HealthResponse(BaseModel):
    """Payload for the GET / health check."""

    status: str
    model_loaded: bool
    model_type: Optional[str] = None        # type(MODEL).__name__ once loaded
    classes: Optional[list[str]] = None     # encoder's class labels once loaded
133
+
134
+
135
# ── Core inference ────────────────────────────────────────────────────────────
def _infer(text: str) -> PredictResponse:
    """Run the clean → vectorize → predict pipeline for one text.

    Raises:
        HTTPException(422): if the text is empty after preprocessing.
    """
    t0 = time.perf_counter()

    cleaned = cleaning(text)
    if not cleaned.strip():
        raise HTTPException(status_code=422, detail="Text is empty after preprocessing.")

    vec = VECTORIZER.transform([cleaned])
    pred_idx = MODEL.predict(vec)[0]
    label = ENCODER.inverse_transform([pred_idx])[0]

    # predict_proba columns follow MODEL.classes_, which is not guaranteed to
    # be 0..n-1, so look up the predicted class's column instead of indexing
    # with the encoded label directly. (Identical result when classes_ is a
    # plain range, as with LabelEncoder-encoded targets.)
    proba = MODEL.predict_proba(vec)[0]
    model_classes = list(MODEL.classes_)
    confidence = float(proba[model_classes.index(pred_idx)])

    # Human-readable label for each probability column, sorted by probability
    # in descending order.
    column_labels = ENCODER.inverse_transform(MODEL.classes_)
    probs_sorted = [
        ClassProbability(label=cls, probability=round(float(p), 4))
        for cls, p in sorted(
            zip(column_labels, proba),
            key=lambda x: x[1],
            reverse=True,
        )
    ]

    return PredictResponse(
        label=label,
        confidence=round(confidence, 4),
        probabilities=probs_sorted,
        cleaned_input=cleaned,
        latency_ms=round((time.perf_counter() - t0) * 1000, 2),
    )
165
+
166
+
167
# ── Routes ────────────────────────────────────────────────────────────────────
@app.get("/", response_model=HealthResponse)
def health():
    """Health check: report whether artifacts are loaded plus basic metadata."""
    loaded = MODEL is not None
    return HealthResponse(
        status="ok",
        model_loaded=loaded,
        # Use the same `is not None` test as model_loaded: some estimators
        # (e.g. an empty sklearn Pipeline) are falsy, and the original
        # truthiness check would report model_type=None despite
        # model_loaded=True.
        model_type=type(MODEL).__name__ if loaded else None,
        classes=list(ENCODER.classes_) if ENCODER is not None else None,
    )
176
+
177
+
178
@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest):
    """Classify a single text; 503 until artifacts are loaded."""
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model not loaded.")
    return _infer(req.text)
183
+
184
+
185
@app.post("/predict/batch", response_model=BatchPredictResponse)
def predict_batch(req: BatchPredictRequest):
    """Classify up to 50 texts sequentially in one request."""
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model not loaded.")
    started = time.perf_counter()
    outcomes = []
    for item in req.texts:
        outcomes.append(_infer(item))
    elapsed_ms = round((time.perf_counter() - started) * 1000, 2)
    return BatchPredictResponse(
        results=outcomes,
        total_latency_ms=elapsed_ms,
    )
195
+
196
+
197
@app.get("/classes")
def get_classes():
    """List the class labels the encoder was fitted on; 503 until loaded."""
    if ENCODER is None:
        raise HTTPException(status_code=503, detail="Model not loaded.")
    return {"classes": list(ENCODER.classes_)}
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi>=0.111.0
2
+ uvicorn[standard]>=0.29.0
3
+ pydantic>=2.6.0
4
+ scikit-learn>=1.4.0
5
+ nltk>=3.8.1
6
+ numpy>=1.26.0