Spaces:
Sleeping
Sleeping
andykr1k commited on
Commit ·
d2b6dbf
1
Parent(s): d800d2b
Changed output
Browse files
app.py
CHANGED
|
@@ -149,7 +149,19 @@ def get_recommendations(user_id, model, data, G, user_nodes, post_nodes, node2id
|
|
| 149 |
scores = [(post, torch.dot(user_embed, embeddings[node2idx[post]]).item()) for post in post_nodes if post not in user_interacted]
|
| 150 |
scores = sorted(scores, key=lambda x: x[1], reverse=True)
|
| 151 |
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
@app.post("/rebuild")
|
| 155 |
async def rebuild_handler():
|
|
@@ -161,8 +173,13 @@ async def get_recommendations_handler(user_id: str = Query(...)):
|
|
| 161 |
if trained_model is None:
|
| 162 |
raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
@app.get("/")
|
| 168 |
async def health_check():
|
|
|
|
| 149 |
scores = [(post, torch.dot(user_embed, embeddings[node2idx[post]]).item()) for post in post_nodes if post not in user_interacted]
|
| 150 |
scores = sorted(scores, key=lambda x: x[1], reverse=True)
|
| 151 |
|
| 152 |
+
# Return the top_k post IDs
|
| 153 |
+
recommended_post_ids = [post for post, _ in scores[:top_k]]
|
| 154 |
+
return recommended_post_ids
|
| 155 |
+
|
| 156 |
+
def fetch_full_post_records(post_ids):
|
| 157 |
+
"""Fetch full post records from Supabase for the given post IDs."""
|
| 158 |
+
supabase = get_supabase_client()
|
| 159 |
+
if not post_ids:
|
| 160 |
+
return []
|
| 161 |
+
|
| 162 |
+
# Query Supabase for all columns of the posts table where id is in the list of recommended post_ids
|
| 163 |
+
response = supabase.table('posts').select('*').in_('id', post_ids).execute()
|
| 164 |
+
return response.data
|
| 165 |
|
| 166 |
@app.post("/rebuild")
|
| 167 |
async def rebuild_handler():
|
|
|
|
| 173 |
if trained_model is None:
|
| 174 |
raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
|
| 175 |
|
| 176 |
+
# Get recommended post IDs
|
| 177 |
+
recommended_post_ids = get_recommendations(user_id, trained_model, pyg_data, G, user_nodes, post_nodes, node2idx)
|
| 178 |
+
|
| 179 |
+
# Fetch full post records for the recommended post IDs
|
| 180 |
+
full_post_records = fetch_full_post_records(recommended_post_ids)
|
| 181 |
+
|
| 182 |
+
return {"status": "success", "recommendations": full_post_records}
|
| 183 |
|
| 184 |
@app.get("/")
|
| 185 |
async def health_check():
|