junaid17 commited on
Commit
da6e212
·
verified ·
1 Parent(s): 0ec91c9

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +187 -0
  2. goemotions_bilstm_checkpoint.pth +3 -0
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import re

import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
7
+
8
# ===============================
# App Init
# ===============================
app = FastAPI(title="GoEmotions Sentiment API", version="1.0")

# Prefer the GPU when one is available; otherwise run inference on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+
15
+
16
# ===============================
# Emotion Mapping
# ===============================
# Index order must match the model's 28 output logits.
emotion_map = [
    "admiration", "amusement", "anger", "annoyance", "approval", "caring",
    "confusion", "curiosity", "desire", "disappointment", "disapproval",
    "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
    "joy", "love", "nervousness", "optimism", "pride", "realization",
    "relief", "remorse", "sadness", "surprise", "neutral",
]

# Coarse sentiment buckets; together they cover all 28 emotions exactly once.
POSITIVE_EMOTIONS = {
    "admiration", "amusement", "approval", "caring", "desire", "excitement",
    "gratitude", "joy", "love", "optimism", "pride", "relief",
}

NEGATIVE_EMOTIONS = {
    "anger", "annoyance", "disappointment", "disapproval", "disgust",
    "embarrassment", "fear", "grief", "nervousness", "remorse", "sadness",
}

NEUTRAL_EMOTIONS = {
    "confusion", "curiosity", "realization", "surprise", "neutral",
}
39
+
40
+
41
# ===============================
# Text Utils
# ===============================
def simple_tokenize(text):
    """Split *text* on whitespace and return the list of tokens."""
    return text.split()
46
+
47
def clean_text(text):
    """Normalize raw input: lowercase, drop non-alphanumerics, collapse spaces."""
    lowered = text.lower()
    # Replace anything that is not a-z, 0-9, or whitespace with a space,
    # then squeeze runs of whitespace down to single spaces.
    alnum_only = re.sub(r'[^a-z0-9\s]', ' ', lowered)
    return re.sub(r'\s+', ' ', alnum_only).strip()
52
+
53
+
54
# ===============================
# Model Definition
# ===============================
class GoEmotionsLSTM(nn.Module):
    """Bidirectional LSTM classifier over padded token-index sequences.

    Produces one raw logit per emotion class (multi-label; callers apply
    sigmoid themselves).
    """

    def __init__(self, vocab_size, embed_dim=200, hidden_dim=256, num_classes=28, num_layers=2):
        super().__init__()

        # Index 0 is the padding token; its embedding stays a zero vector.
        self.embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.2,
            bidirectional=True,
        )

        # Forward + backward final states are concatenated -> 2 * hidden_dim.
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        """Map (batch, seq_len) int64 token ids to (batch, num_classes) logits."""
        embedded = self.embeddings(x)
        _, (hidden, _) = self.lstm(embedded)

        # hidden is (num_layers * 2, batch, hidden_dim); the last two rows
        # hold the top layer's forward and backward final states.
        final_state = torch.cat((hidden[-2], hidden[-1]), dim=1)
        return self.fc(final_state)
85
+
86
+
87
# ===============================
# Globals (Loaded Once)
# ===============================
# All three are populated by load_model() at application startup.
model = None    # trained GoEmotionsLSTM instance
vocab = None    # token -> index mapping taken from the checkpoint
max_len = None  # fixed sequence length used for padding/truncation
93
+
94
+
95
# ===============================
# Load Model at Startup
# ===============================
@app.on_event("startup")
def load_model():
    """Load the checkpoint, rebuild the BiLSTM, and fill the module globals."""
    global model, vocab, max_len

    print("Loading GoEmotions BiLSTM model...")

    # NOTE(review): torch.load with default args unpickles arbitrary objects.
    # Fine for a trusted local checkpoint; consider weights_only=True if the
    # vocab/max_len entries are plain Python types.
    ckpt = torch.load("goemotions_bilstm_checkpoint.pth", map_location=DEVICE)

    vocab = ckpt["vocab"]
    max_len = ckpt["max_len"]

    net = GoEmotionsLSTM(vocab_size=len(vocab))
    net.load_state_dict(ckpt["model_state"])
    net.to(DEVICE)
    net.eval()
    model = net

    print("Model loaded successfully.")
115
+
116
+
117
# ===============================
# Request Schema
# ===============================
class PredictRequest(BaseModel):
    """JSON body for /predict: the raw text to classify."""
    text: str
122
+
123
+
124
# ===============================
# Status Endpoint
# ===============================
@app.get("/status")
def status():
    """Report whether the startup hook has finished loading the model."""
    if model is not None:
        return {"status": "ok", "model_loaded": True}
    return {"status": "loading"}
132
+
133
+
134
# ===============================
# Sentiment Aggregation Logic
# ===============================
def aggregate_sentiment(probs):
    """Collapse per-emotion probabilities into a (label, score) pair.

    Sums probabilities within the positive / negative / neutral buckets
    and returns the strictly-winning bucket's label with its summed score;
    any tie at the top falls through to Neutral.  NOTE(review): with
    independent sigmoids the summed score is not bounded by 1.0.
    """
    totals = {"Positive": 0.0, "Negative": 0.0, "Neutral": 0.0}

    for i, prob in enumerate(probs):
        name = emotion_map[i]
        if name in POSITIVE_EMOTIONS:
            totals["Positive"] += prob
        elif name in NEGATIVE_EMOTIONS:
            totals["Negative"] += prob
        else:
            totals["Neutral"] += prob

    pos = totals["Positive"]
    neg = totals["Negative"]
    neu = totals["Neutral"]

    if pos > neg and pos > neu:
        return "Positive", pos
    if neg > pos and neg > neu:
        return "Negative", neg
    return "Neutral", neu
157
+
158
+
159
# ===============================
# Prediction Endpoint
# ===============================
@app.post("/predict")
def predict(req: PredictRequest):
    """Classify *req.text* and return aggregate sentiment plus a confidence %.

    Raises:
        HTTPException: 503 if a request arrives before the startup hook
            has finished loading the model (previously this crashed with
            an unhandled AttributeError -> HTTP 500).
    """
    # Guard against requests racing load_model(); the globals are None
    # until the startup event completes.
    if model is None:
        raise HTTPException(status_code=503, detail="Model is still loading")

    text = clean_text(req.text)
    tokens = simple_tokenize(text)

    # Map tokens to vocab indices; unknown tokens fall back to <UNK> = 1.
    seq = [vocab.get(tok, 1) for tok in tokens]

    # Pad with <PAD> (or truncate) to the fixed training-time length.
    if len(seq) < max_len:
        seq += [vocab["<PAD>"]] * (max_len - len(seq))
    else:
        seq = seq[:max_len]

    x = torch.tensor([seq], dtype=torch.long).to(DEVICE)

    with torch.no_grad():
        logits = model(x)
        # Multi-label head: independent sigmoid per emotion.
        probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()

    sentiment, score = aggregate_sentiment(probs)

    # NOTE(review): score is a sum of sigmoid outputs, so "confidence" can
    # exceed 100 — confirm whether clients expect a bounded percentage.
    return {
        "sentiment": sentiment,
        "confidence": round(float(score) * 100, 2)
    }
goemotions_bilstm_checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2061697f00e13a048b56bfd5b8ce721ba5cdd91143ce5a1d4e1e6a272ff7944d
3
+ size 16386991