andykr1k commited on
Commit
d2b6dbf
·
1 Parent(s): d800d2b

Changed output

Browse files
Files changed (1) hide show
  1. app.py +20 -3
app.py CHANGED
@@ -149,7 +149,19 @@ def get_recommendations(user_id, model, data, G, user_nodes, post_nodes, node2id
149
  scores = [(post, torch.dot(user_embed, embeddings[node2idx[post]]).item()) for post in post_nodes if post not in user_interacted]
150
  scores = sorted(scores, key=lambda x: x[1], reverse=True)
151
 
152
- return [post for post, _ in scores[:top_k]]
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  @app.post("/rebuild")
155
  async def rebuild_handler():
@@ -161,8 +173,13 @@ async def get_recommendations_handler(user_id: str = Query(...)):
161
  if trained_model is None:
162
  raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
163
 
164
- recs = get_recommendations(user_id, trained_model, pyg_data, G, user_nodes, post_nodes, node2idx)
165
- return {"status": "success", "recommendations": recs}
 
 
 
 
 
166
 
167
  @app.get("/")
168
  async def health_check():
 
149
  scores = [(post, torch.dot(user_embed, embeddings[node2idx[post]]).item()) for post in post_nodes if post not in user_interacted]
150
  scores = sorted(scores, key=lambda x: x[1], reverse=True)
151
 
152
+ # Return the top_k post IDs
153
+ recommended_post_ids = [post for post, _ in scores[:top_k]]
154
+ return recommended_post_ids
155
+
156
+ def fetch_full_post_records(post_ids):
157
+ """Fetch full post records from Supabase for the given post IDs."""
158
+ supabase = get_supabase_client()
159
+ if not post_ids:
160
+ return []
161
+
162
+ # Query Supabase for all columns of the posts table where id is in the list of recommended post_ids
163
+ response = supabase.table('posts').select('*').in_('id', post_ids).execute()
164
+ return response.data
165
 
166
  @app.post("/rebuild")
167
  async def rebuild_handler():
 
173
  if trained_model is None:
174
  raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
175
 
176
+ # Get recommended post IDs
177
+ recommended_post_ids = get_recommendations(user_id, trained_model, pyg_data, G, user_nodes, post_nodes, node2idx)
178
+
179
+ # Fetch full post records for the recommended post IDs
180
+ full_post_records = fetch_full_post_records(recommended_post_ids)
181
+
182
+ return {"status": "success", "recommendations": full_post_records}
183
 
184
  @app.get("/")
185
  async def health_check():