andykr1k commited on
Commit
af79f6c
·
1 Parent(s): 41305e5

added scheduler, logging and optimization updates

Browse files
Files changed (2) hide show
  1. app.py +130 -54
  2. requirements.txt +2 -1
app.py CHANGED
@@ -12,7 +12,17 @@ from torch_geometric.nn import SAGEConv
12
  from supabase import create_client
13
  from fastapi import FastAPI, HTTPException, Query
14
  from fastapi.middleware.cors import CORSMiddleware
 
15
  from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
16
 
17
  load_dotenv()
18
 
@@ -46,24 +56,32 @@ def get_supabase_client():
46
  def load_and_preprocess_data():
47
  supabase = get_supabase_client()
48
 
49
- profiles = pd.DataFrame(supabase.table('profiles').select('id').execute().data)
50
- posts = pd.DataFrame(supabase.table('posts').select('id, author').execute().data)
51
- likes = pd.DataFrame(supabase.table('likes').select('user_id, post_id').execute().data)
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  bipartite = nx.DiGraph()
54
 
55
- user_set = set(posts['author']) | set(likes['user_id'])
56
- post_set = set(posts['id'])
57
 
58
- for user in user_set:
59
- bipartite.add_node(user, type='user')
60
- for post in post_set:
61
- bipartite.add_node(post, type='post')
62
 
63
- for _, row in posts.iterrows():
64
- bipartite.add_edge(row['author'], row['id'])
65
- for _, row in likes.iterrows():
66
- bipartite.add_edge(row['user_id'], row['post_id'])
67
 
68
  return bipartite
69
 
@@ -116,54 +134,73 @@ def train_model(model, data, pos_edges, neg_edges, epochs=200):
116
 
117
  def rebuild_model():
118
  global G, features, user_nodes, post_nodes, node2idx, pyg_data, trained_model
119
-
120
- G = load_and_preprocess_data()
121
- user_nodes = sorted(n for n, attr in G.nodes(data=True) if attr['type'] == 'user')
122
- post_nodes = sorted(n for n, attr in G.nodes(data=True) if attr['type'] == 'post')
123
-
124
- all_nodes = user_nodes + post_nodes
125
- node2idx = {node: i for i, node in enumerate(all_nodes)}
126
-
127
- features = torch.eye(len(all_nodes))
128
- pyg_data = from_networkx(G)
129
- pyg_data.x = features
130
-
131
- pos_edges, neg_edges = prepare_training_data(G, node2idx, user_nodes, post_nodes)
132
-
133
- input_dim = features.shape[1]
134
- model = GraphRecommender(input_dim)
135
- trained_model = train_model(model, pyg_data, pos_edges, neg_edges)
136
-
137
- def get_recommendations(user_id, model, data, G, user_nodes, post_nodes, node2idx, top_k=10):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  if user_id not in user_nodes:
139
  return []
140
 
141
  user_idx = node2idx[user_id]
142
-
143
  user_interacted = {v for _, v in G.out_edges(user_id) if G.nodes[v]['type'] == 'post'}
 
144
 
145
  with torch.no_grad():
146
  embeddings = model(data.x, data.edge_index)
147
- user_embed = embeddings[user_idx]
 
148
 
149
- scores = [(post, torch.dot(user_embed, embeddings[node2idx[post]]).item()) for post in post_nodes if post not in user_interacted]
150
- scores = sorted(scores, key=lambda x: x[1], reverse=True)
151
 
152
- # Return the top_k post IDs
153
- recommended_post_ids = [post for post, _ in scores[:top_k]]
154
- return recommended_post_ids
155
 
156
- def fetch_full_post_records(post_ids):
157
- """Fetch full post records from Supabase for the given post IDs."""
 
158
  supabase = get_supabase_client()
159
  if not post_ids:
160
  return []
161
 
162
- response = supabase.table('posts').select('*').in_('id', post_ids).execute()
163
- records = response.data
164
-
165
- for record in records:
166
- record['type'] = 'post'
 
 
 
167
 
168
  return records
169
 
@@ -177,16 +214,55 @@ async def get_recommendations_handler(user_id: str = Query(...)):
177
  if trained_model is None:
178
  raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
179
 
180
- # Get recommended post IDs
181
- recommended_post_ids = get_recommendations(user_id, trained_model, pyg_data, G, user_nodes, post_nodes, node2idx)
182
-
183
- # Fetch full post records for the recommended post IDs
184
- full_post_records = fetch_full_post_records(recommended_post_ids)
185
-
186
- return {"status": "success", "recommendations": full_post_records}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  @app.get("/")
189
  async def health_check():
190
  return {"status": "success", "message": "Service operational"}
191
 
192
- rebuild_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  from supabase import create_client
13
  from fastapi import FastAPI, HTTPException, Query
14
  from fastapi.middleware.cors import CORSMiddleware
15
+ from fastapi.responses import StreamingResponse
16
  from dotenv import load_dotenv
17
+ import json
18
+ from apscheduler.schedulers.background import BackgroundScheduler
19
+ from apscheduler.triggers.cron import CronTrigger
20
+ import logging
21
+ import uvicorn
22
+
23
+ # Configure logging
24
+ logging.basicConfig(level=logging.INFO)
25
+ logger = logging.getLogger(__name__)
26
 
27
  load_dotenv()
28
 
 
56
  def load_and_preprocess_data():
57
  supabase = get_supabase_client()
58
 
59
+ def fetch_table(table, columns, chunk_size=1000):
60
+ offset = 0
61
+ all_data = []
62
+ while True:
63
+ response = supabase.table(table).select(columns).range(offset, offset + chunk_size - 1).execute()
64
+ data = response.data
65
+ if not data:
66
+ break
67
+ all_data.extend(data)
68
+ offset += chunk_size
69
+ return all_data
70
+
71
+ profiles = fetch_table('profiles', 'id')
72
+ posts = fetch_table('posts', 'id, author')
73
+ likes = fetch_table('likes', 'user_id, post_id')
74
 
75
  bipartite = nx.DiGraph()
76
 
77
+ user_set = {p['author'] for p in posts} | {l['user_id'] for l in likes}
78
+ post_set = {p['id'] for p in posts}
79
 
80
+ bipartite.add_nodes_from(user_set, type='user')
81
+ bipartite.add_nodes_from(post_set, type='post')
 
 
82
 
83
+ bipartite.add_edges_from((p['author'], p['id']) for p in posts)
84
+ bipartite.add_edges_from((l['user_id'], l['post_id']) for l in likes)
 
 
85
 
86
  return bipartite
87
 
 
134
 
135
  def rebuild_model():
136
  global G, features, user_nodes, post_nodes, node2idx, pyg_data, trained_model
137
+ logger.info("Starting model rebuild at 3:30 AM")
138
+ try:
139
+ G = load_and_preprocess_data()
140
+ user_nodes = sorted(n for n, attr in G.nodes(data=True) if attr['type'] == 'user')
141
+ post_nodes = sorted(n for n, attr in G.nodes(data=True) if attr['type'] == 'post')
142
+
143
+ all_nodes = user_nodes + post_nodes
144
+ node2idx = {node: i for i, node in enumerate(all_nodes)}
145
+
146
+ features = torch.sparse_coo_tensor(
147
+ torch.arange(len(all_nodes)).repeat(2, 1),
148
+ torch.ones(len(all_nodes)),
149
+ (len(all_nodes), len(all_nodes))
150
+ )
151
+ pyg_data = from_networkx(G)
152
+ pyg_data.x = features
153
+
154
+ pos_edges, neg_edges = prepare_training_data(G, node2idx, user_nodes, post_nodes)
155
+
156
+ input_dim = features.shape[1]
157
+ model = GraphRecommender(input_dim)
158
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
159
+ model = model.to(device)
160
+ pyg_data = pyg_data.to(device)
161
+ pos_edges = pos_edges.to(device)
162
+ neg_edges = neg_edges.to(device)
163
+
164
+ trained_model = train_model(model, pyg_data, pos_edges, neg_edges)
165
+ trained_model = trained_model.to('cpu')
166
+ logger.info("Model rebuild completed successfully")
167
+ except Exception as e:
168
+ logger.error(f"Error during model rebuild: {str(e)}")
169
+ raise
170
+
171
+ def get_recommendations(user_id, model, data, G, user_nodes, post_nodes, node2idx):
172
  if user_id not in user_nodes:
173
  return []
174
 
175
  user_idx = node2idx[user_id]
 
176
  user_interacted = {v for _, v in G.out_edges(user_id) if G.nodes[v]['type'] == 'post'}
177
+ post_indices = [node2idx[p] for p in post_nodes if p not in user_interacted]
178
 
179
  with torch.no_grad():
180
  embeddings = model(data.x, data.edge_index)
181
+ user_embed = embeddings[user_idx].unsqueeze(0)
182
+ post_embeds = embeddings[post_indices]
183
 
184
+ scores = torch.matmul(user_embed, post_embeds.T).squeeze(0)
 
185
 
186
+ post_scores = [(post_nodes[i], score.item()) for i, score in zip(post_indices, scores)]
187
+ post_scores = sorted(post_scores, key=lambda x: x[1], reverse=True)
 
188
 
189
+ return [{"post_id": post, "score": score} for post, score in post_scores]
190
+
191
+ def fetch_full_post_records(post_ids, batch_size=1000):
192
  supabase = get_supabase_client()
193
  if not post_ids:
194
  return []
195
 
196
+ records = []
197
+ for i in range(0, len(post_ids), batch_size):
198
+ batch_ids = post_ids[i:i + batch_size]
199
+ response = supabase.table('posts').select('*').in_('id', batch_ids).execute()
200
+ batch_records = response.data
201
+ for record in batch_records:
202
+ record['type'] = 'post'
203
+ records.extend(batch_records)
204
 
205
  return records
206
 
 
214
  if trained_model is None:
215
  raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
216
 
217
+ recommended_posts = get_recommendations(user_id, trained_model, pyg_data, G, user_nodes, post_nodes, node2idx)
218
+ if not recommended_posts:
219
+ return {"status": "success", "recommendations": []}
220
+
221
+ post_ids = [post["post_id"] for post in recommended_posts]
222
+ full_post_records = fetch_full_post_records(post_ids)
223
+
224
+ post_dict = {post["id"]: post for post in full_post_records}
225
+ ordered_recommendations = []
226
+ for post in recommended_posts:
227
+ post_id = post["post_id"]
228
+ if post_id in post_dict:
229
+ post_record = post_dict[post_id]
230
+ post_record["score"] = post["score"]
231
+ ordered_recommendations.append(post_record)
232
+
233
+ def generate():
234
+ yield '{"status": "success", "recommendations": ['
235
+ for i, rec in enumerate(ordered_recommendations):
236
+ yield json.dumps(rec)
237
+ if i < len(ordered_recommendations) - 1:
238
+ yield ','
239
+ yield ']}'
240
+
241
+ return StreamingResponse(generate(), media_type="application/json")
242
 
243
  @app.get("/")
244
  async def health_check():
245
  return {"status": "success", "message": "Service operational"}
246
 
247
+ scheduler = BackgroundScheduler(timezone="PST")
248
+
249
+ scheduler.add_job(
250
+ rebuild_model,
251
+ trigger=CronTrigger(hour=3, minute=30),
252
+ id='daily_model_rebuild',
253
+ replace_existing=True
254
+ )
255
+
256
+ @app.on_event("startup")
257
+ async def startup_event():
258
+ rebuild_model()
259
+ scheduler.start()
260
+ logger.info("Scheduler started, model will rebuild daily at 3:30 AM")
261
+
262
+ @app.on_event("shutdown")
263
+ async def shutdown_event():
264
+ scheduler.shutdown()
265
+ logger.info("Scheduler shut down")
266
+
267
+ if __name__ == "__main__":
268
+ uvicorn.run(app, host="0.0.0.0", port=8000)
requirements.txt CHANGED
@@ -6,4 +6,5 @@ torch_geometric
6
  supabase
7
  fastapi
8
  python-dotenv
9
- uvicorn
 
 
6
  supabase
7
  fastapi
8
  python-dotenv
9
+ uvicorn
10
+ apscheduler