andykr1k commited on
Commit
d800d2b
·
1 Parent(s): b3e4edb

Fixing small bugs

Browse files
Files changed (1) hide show
  1. app.py +44 -56
app.py CHANGED
@@ -33,14 +33,9 @@ torch.manual_seed(SEED)
33
  if torch.cuda.is_available():
34
  torch.cuda.manual_seed_all(SEED)
35
 
 
36
  global G, features, user_nodes, post_nodes, node2idx, pyg_data, trained_model
37
- G = None
38
- features = None
39
- user_nodes = None
40
- post_nodes = None
41
- node2idx = None
42
- pyg_data = None
43
- trained_model = None
44
 
45
  SUPABASE_URL = os.getenv('supabaseUrl')
46
  SUPABASE_KEY = os.getenv('supabaseAnonKey')
@@ -48,41 +43,27 @@ SUPABASE_KEY = os.getenv('supabaseAnonKey')
48
  def get_supabase_client():
49
  return create_client(SUPABASE_URL, SUPABASE_KEY)
50
 
51
- def load_and_preprocess_data_for_posts():
52
  supabase = get_supabase_client()
53
 
54
- profiles_response = supabase.table('profiles').select('id').execute()
55
- df_profiles = pd.DataFrame(profiles_response.data)
56
-
57
- posts_response = supabase.table('posts').select('id, author').execute()
58
- df_posts = pd.DataFrame(posts_response.data)
59
-
60
- likes_response = supabase.table('likes').select('user_id, post_id').execute()
61
- df_likes = pd.DataFrame(likes_response.data)
62
 
63
  bipartite = nx.DiGraph()
64
 
65
- user_set = set(df_posts['author'].dropna().tolist()) | set(df_likes['user_id'].dropna().tolist())
66
- post_set = set(df_posts['id'].tolist())
67
 
68
  for user in user_set:
69
- if user:
70
- bipartite.add_node(user, type='user')
71
-
72
  for post in post_set:
73
  bipartite.add_node(post, type='post')
74
 
75
- for _, row in df_posts.iterrows():
76
- user = row['author']
77
- post = row['id']
78
- if user and post:
79
- bipartite.add_edge(user, post)
80
-
81
- for _, row in df_likes.iterrows():
82
- user = row['user_id']
83
- post = row['post_id']
84
- if user and post:
85
- bipartite.add_edge(user, post)
86
 
87
  return bipartite
88
 
@@ -101,21 +82,18 @@ class GraphRecommender(nn.Module):
101
 
102
  def prepare_training_data(G, node2idx, user_nodes, post_nodes):
103
  pos_edges = [(node2idx[u], node2idx[v]) for u, v in G.edges() if G.nodes[u]['type'] == 'user' and G.nodes[v]['type'] == 'post']
104
- pos_edge_index = torch.tensor(pos_edges).T
105
 
106
  all_possible = [(node2idx[u], node2idx[p]) for u in user_nodes for p in post_nodes]
107
  pos_set = set(pos_edges)
108
  neg_candidates = [pair for pair in all_possible if pair not in pos_set]
 
109
  neg_sample_size = min(len(pos_edges), len(neg_candidates))
110
  neg_edges = random.sample(neg_candidates, neg_sample_size)
111
- neg_edge_index = torch.tensor(neg_edges).T
112
 
113
- return pos_edge_index, neg_edge_index
114
 
115
  def train_model(model, data, pos_edges, neg_edges, epochs=200):
116
  optimizer = optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-4)
117
- best_loss = float('inf')
118
- patience_counter = 0
119
 
120
  for epoch in range(epochs):
121
  model.train()
@@ -128,29 +106,21 @@ def train_model(model, data, pos_edges, neg_edges, epochs=200):
128
 
129
  pos_loss = F.binary_cross_entropy_with_logits(pos_scores, torch.ones_like(pos_scores))
130
  neg_loss = F.binary_cross_entropy_with_logits(neg_scores, torch.zeros_like(neg_scores))
131
- reg_loss = torch.norm(embeddings, p=2)
132
 
133
- total_loss = pos_loss + neg_loss + 0.001 * reg_loss
134
 
135
  total_loss.backward()
136
  optimizer.step()
137
 
138
- if total_loss < best_loss:
139
- best_loss = total_loss
140
- patience_counter = 0
141
- else:
142
- patience_counter += 1
143
- if patience_counter >= 20:
144
- break
145
-
146
  return model
147
 
148
  def rebuild_model():
149
  global G, features, user_nodes, post_nodes, node2idx, pyg_data, trained_model
150
- G = load_and_preprocess_data_for_posts()
151
 
152
- user_nodes = sorted(n for n, attr in G.nodes(data=True) if attr.get('type') == 'user')
153
- post_nodes = sorted(n for n, attr in G.nodes(data=True) if attr.get('type') == 'post')
 
 
154
  all_nodes = user_nodes + post_nodes
155
  node2idx = {node: i for i, node in enumerate(all_nodes)}
156
 
@@ -158,26 +128,44 @@ def rebuild_model():
158
  pyg_data = from_networkx(G)
159
  pyg_data.x = features
160
 
161
- pos_edge_index, neg_edge_index = prepare_training_data(G, node2idx, user_nodes, post_nodes)
162
 
163
  input_dim = features.shape[1]
164
- model = GraphRecommender(input_dim=input_dim)
165
- trained_model = train_model(model, pyg_data, pos_edge_index, neg_edge_index)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  @app.post("/rebuild")
168
  async def rebuild_handler():
169
  rebuild_model()
170
- return {"status": "success", "message": "Model and data rebuilt successfully"}
171
 
172
  @app.get("/recommend/feed")
173
  async def get_recommendations_handler(user_id: str = Query(...)):
174
  if trained_model is None:
175
  raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
 
176
  recs = get_recommendations(user_id, trained_model, pyg_data, G, user_nodes, post_nodes, node2idx)
177
  return {"status": "success", "recommendations": recs}
178
 
179
  @app.get("/")
180
  async def health_check():
181
- return {"status": "success", "message": "Recommendation service operational"}
182
 
183
- rebuild_model()
 
33
  if torch.cuda.is_available():
34
  torch.cuda.manual_seed_all(SEED)
35
 
36
+ # Global Variables
37
  global G, features, user_nodes, post_nodes, node2idx, pyg_data, trained_model
38
+ G = features = user_nodes = post_nodes = node2idx = pyg_data = trained_model = None
 
 
 
 
 
 
39
 
40
  SUPABASE_URL = os.getenv('supabaseUrl')
41
  SUPABASE_KEY = os.getenv('supabaseAnonKey')
 
43
  def get_supabase_client():
44
  return create_client(SUPABASE_URL, SUPABASE_KEY)
45
 
46
+ def load_and_preprocess_data():
47
  supabase = get_supabase_client()
48
 
49
+ profiles = pd.DataFrame(supabase.table('profiles').select('id').execute().data)
50
+ posts = pd.DataFrame(supabase.table('posts').select('id, author').execute().data)
51
+ likes = pd.DataFrame(supabase.table('likes').select('user_id, post_id').execute().data)
 
 
 
 
 
52
 
53
  bipartite = nx.DiGraph()
54
 
55
+ user_set = set(posts['author']) | set(likes['user_id'])
56
+ post_set = set(posts['id'])
57
 
58
  for user in user_set:
59
+ bipartite.add_node(user, type='user')
 
 
60
  for post in post_set:
61
  bipartite.add_node(post, type='post')
62
 
63
+ for _, row in posts.iterrows():
64
+ bipartite.add_edge(row['author'], row['id'])
65
+ for _, row in likes.iterrows():
66
+ bipartite.add_edge(row['user_id'], row['post_id'])
 
 
 
 
 
 
 
67
 
68
  return bipartite
69
 
 
82
 
83
  def prepare_training_data(G, node2idx, user_nodes, post_nodes):
84
  pos_edges = [(node2idx[u], node2idx[v]) for u, v in G.edges() if G.nodes[u]['type'] == 'user' and G.nodes[v]['type'] == 'post']
 
85
 
86
  all_possible = [(node2idx[u], node2idx[p]) for u in user_nodes for p in post_nodes]
87
  pos_set = set(pos_edges)
88
  neg_candidates = [pair for pair in all_possible if pair not in pos_set]
89
+
90
  neg_sample_size = min(len(pos_edges), len(neg_candidates))
91
  neg_edges = random.sample(neg_candidates, neg_sample_size)
 
92
 
93
+ return torch.tensor(pos_edges).T, torch.tensor(neg_edges).T
94
 
95
  def train_model(model, data, pos_edges, neg_edges, epochs=200):
96
  optimizer = optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-4)
 
 
97
 
98
  for epoch in range(epochs):
99
  model.train()
 
106
 
107
  pos_loss = F.binary_cross_entropy_with_logits(pos_scores, torch.ones_like(pos_scores))
108
  neg_loss = F.binary_cross_entropy_with_logits(neg_scores, torch.zeros_like(neg_scores))
 
109
 
110
+ total_loss = pos_loss + neg_loss
111
 
112
  total_loss.backward()
113
  optimizer.step()
114
 
 
 
 
 
 
 
 
 
115
  return model
116
 
117
  def rebuild_model():
118
  global G, features, user_nodes, post_nodes, node2idx, pyg_data, trained_model
 
119
 
120
+ G = load_and_preprocess_data()
121
+ user_nodes = sorted(n for n, attr in G.nodes(data=True) if attr['type'] == 'user')
122
+ post_nodes = sorted(n for n, attr in G.nodes(data=True) if attr['type'] == 'post')
123
+
124
  all_nodes = user_nodes + post_nodes
125
  node2idx = {node: i for i, node in enumerate(all_nodes)}
126
 
 
128
  pyg_data = from_networkx(G)
129
  pyg_data.x = features
130
 
131
+ pos_edges, neg_edges = prepare_training_data(G, node2idx, user_nodes, post_nodes)
132
 
133
  input_dim = features.shape[1]
134
+ model = GraphRecommender(input_dim)
135
+ trained_model = train_model(model, pyg_data, pos_edges, neg_edges)
136
+
137
+ def get_recommendations(user_id, model, data, G, user_nodes, post_nodes, node2idx, top_k=10):
138
+ if user_id not in user_nodes:
139
+ return []
140
+
141
+ user_idx = node2idx[user_id]
142
+
143
+ user_interacted = {v for _, v in G.out_edges(user_id) if G.nodes[v]['type'] == 'post'}
144
+
145
+ with torch.no_grad():
146
+ embeddings = model(data.x, data.edge_index)
147
+ user_embed = embeddings[user_idx]
148
+
149
+ scores = [(post, torch.dot(user_embed, embeddings[node2idx[post]]).item()) for post in post_nodes if post not in user_interacted]
150
+ scores = sorted(scores, key=lambda x: x[1], reverse=True)
151
+
152
+ return [post for post, _ in scores[:top_k]]
153
 
154
  @app.post("/rebuild")
155
  async def rebuild_handler():
156
  rebuild_model()
157
+ return {"status": "success", "message": "Model rebuilt successfully"}
158
 
159
  @app.get("/recommend/feed")
160
  async def get_recommendations_handler(user_id: str = Query(...)):
161
  if trained_model is None:
162
  raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
163
+
164
  recs = get_recommendations(user_id, trained_model, pyg_data, G, user_nodes, post_nodes, node2idx)
165
  return {"status": "success", "recommendations": recs}
166
 
167
  @app.get("/")
168
  async def health_check():
169
+ return {"status": "success", "message": "Service operational"}
170
 
171
+ rebuild_model()