andykr1k committed on
Commit
b3e4edb
·
1 Parent(s): 43e4bab

changed to user id

Browse files
Files changed (1) hide show
  1. app.py +35 -110
app.py CHANGED
@@ -18,7 +18,6 @@ load_dotenv()
18
 
19
  app = FastAPI()
20
 
21
- # Enable CORS for all origins (adjust as needed)
22
  app.add_middleware(
23
  CORSMiddleware,
24
  allow_origins=["*"],
@@ -34,17 +33,15 @@ torch.manual_seed(SEED)
34
  if torch.cuda.is_available():
35
  torch.cuda.manual_seed_all(SEED)
36
 
37
- # Global variables for our GNN-based post recommender
38
  global G, features, user_nodes, post_nodes, node2idx, pyg_data, trained_model
39
- G = None # Bipartite graph (users and posts)
40
- features = None # Node features (we use identity)
41
- user_nodes = None # Sorted list of user node IDs
42
- post_nodes = None # Sorted list of post node IDs
43
- node2idx = None # Mapping from node ID to index (for features)
44
- pyg_data = None # PyTorch Geometric data object
45
- trained_model = None # Trained GNN model
46
-
47
- SUPABASE_ID = os.getenv('supabaseID')
48
  SUPABASE_URL = os.getenv('supabaseUrl')
49
  SUPABASE_KEY = os.getenv('supabaseAnonKey')
50
 
@@ -52,73 +49,50 @@ def get_supabase_client():
52
  return create_client(SUPABASE_URL, SUPABASE_KEY)
53
 
54
  def load_and_preprocess_data_for_posts():
55
- """
56
- Build a bipartite directed graph from Supabase data:
57
- - Users: derived from profiles (via posts and likes)
58
- - Posts: from the posts table.
59
- Edges:
60
- - From user to post if the user created the post.
61
- - From user to post if the user liked the post.
62
- """
63
  supabase = get_supabase_client()
64
-
65
- # Load profiles (users)
66
- profiles_response = supabase.table('profiles').select('id, username').execute()
67
  df_profiles = pd.DataFrame(profiles_response.data)
68
- # Create mapping from user id to username
69
- uuid_to_username = dict(zip(df_profiles['id'], df_profiles['username']))
70
-
71
- # Load posts (each with an author)
72
  posts_response = supabase.table('posts').select('id, author').execute()
73
  df_posts = pd.DataFrame(posts_response.data)
74
- # Map post authors to usernames
75
- df_posts['username'] = df_posts['author'].map(uuid_to_username)
76
-
77
- # Load likes: records of (user_id, post_id)
78
  likes_response = supabase.table('likes').select('user_id, post_id').execute()
79
  df_likes = pd.DataFrame(likes_response.data)
80
- df_likes['username'] = df_likes['user_id'].map(uuid_to_username)
81
-
82
- # Build bipartite graph (directed: from user to post)
83
  bipartite = nx.DiGraph()
84
-
85
- # Determine set of users (only those who appear in posts or likes)
86
- user_set = set(df_posts['username'].dropna().tolist()) | set(df_likes['username'].dropna().tolist())
87
- # Determine set of posts (by id)
88
  post_set = set(df_posts['id'].tolist())
89
-
90
- # Add user nodes with attribute type 'user'
91
  for user in user_set:
92
- if user: # ensure non-empty
93
  bipartite.add_node(user, type='user')
94
- # Add post nodes with attribute type 'post'
95
  for post in post_set:
96
  bipartite.add_node(post, type='post')
97
-
98
- # Add edges from post creation: user -> post
99
  for _, row in df_posts.iterrows():
100
- user = row['username']
101
  post = row['id']
102
  if user and post:
103
  bipartite.add_edge(user, post)
104
-
105
- # Add edges from likes: user -> post
106
  for _, row in df_likes.iterrows():
107
- user = row['username']
108
  post = row['post_id']
109
  if user and post:
110
  bipartite.add_edge(user, post)
111
-
112
  return bipartite
113
 
114
- # GNN Model using GraphSAGE
115
  class GraphRecommender(nn.Module):
116
  def __init__(self, input_dim, hidden_dim=128, output_dim=64):
117
  super().__init__()
118
  self.conv1 = SAGEConv(input_dim, hidden_dim)
119
  self.conv2 = SAGEConv(hidden_dim, output_dim)
120
  self.dropout = nn.Dropout(0.3)
121
-
122
  def forward(self, x, edge_index):
123
  x = F.relu(self.conv1(x, edge_index))
124
  x = self.dropout(x)
@@ -126,25 +100,16 @@ class GraphRecommender(nn.Module):
126
  return x
127
 
128
  def prepare_training_data(G, node2idx, user_nodes, post_nodes):
129
- """
130
- Create positive edges for training.
131
- Only consider edges from a user node to a post node.
132
- """
133
- pos_edges = []
134
- for u, v in G.edges():
135
- # Only include if u is a user and v is a post
136
- if G.nodes[u].get('type') == 'user' and G.nodes[v].get('type') == 'post':
137
- pos_edges.append((node2idx[u], node2idx[v]))
138
- pos_edge_index = torch.tensor(pos_edges).T # shape: [2, num_pos_edges]
139
-
140
- # For negative sampling, form all possible user->post pairs and subtract positive edges.
141
  all_possible = [(node2idx[u], node2idx[p]) for u in user_nodes for p in post_nodes]
142
  pos_set = set(pos_edges)
143
  neg_candidates = [pair for pair in all_possible if pair not in pos_set]
144
- # Sample as many negatives as positives (if available)
145
  neg_sample_size = min(len(pos_edges), len(neg_candidates))
146
  neg_edges = random.sample(neg_candidates, neg_sample_size)
147
  neg_edge_index = torch.tensor(neg_edges).T
 
148
  return pos_edge_index, neg_edge_index
149
 
150
  def train_model(model, data, pos_edges, neg_edges, epochs=200):
@@ -158,7 +123,6 @@ def train_model(model, data, pos_edges, neg_edges, epochs=200):
158
 
159
  embeddings = model(data.x, data.edge_index)
160
 
161
- # Compute scores for positive and negative edges via dot product
162
  pos_scores = (embeddings[pos_edges[0]] * embeddings[pos_edges[1]]).sum(1)
163
  neg_scores = (embeddings[neg_edges[0]] * embeddings[neg_edges[1]]).sum(1)
164
 
@@ -182,77 +146,38 @@ def train_model(model, data, pos_edges, neg_edges, epochs=200):
182
  return model
183
 
184
  def rebuild_model():
185
- """
186
- Loads the bipartite user-post graph, computes node features,
187
- prepares training data, trains the GNN model, and updates globals.
188
- """
189
  global G, features, user_nodes, post_nodes, node2idx, pyg_data, trained_model
190
  G = load_and_preprocess_data_for_posts()
191
-
192
- # Get sorted lists of user and post nodes
193
- user_nodes = sorted([n for n, attr in G.nodes(data=True) if attr.get('type') == 'user'])
194
- post_nodes = sorted([n for n, attr in G.nodes(data=True) if attr.get('type') == 'post'])
195
  user_nodes = sorted(n for n, attr in G.nodes(data=True) if attr.get('type') == 'user')
196
  post_nodes = sorted(n for n, attr in G.nodes(data=True) if attr.get('type') == 'post')
197
  all_nodes = user_nodes + post_nodes
198
  node2idx = {node: i for i, node in enumerate(all_nodes)}
199
-
200
- # Use identity features (one-hot) for all nodes
201
  features = torch.eye(len(all_nodes))
202
  pyg_data = from_networkx(G)
203
  pyg_data.x = features
204
-
205
  pos_edge_index, neg_edge_index = prepare_training_data(G, node2idx, user_nodes, post_nodes)
206
-
207
  input_dim = features.shape[1]
208
- model = GraphRecommender(input_dim=input_dim, hidden_dim=128, output_dim=64)
209
  trained_model = train_model(model, pyg_data, pos_edge_index, neg_edge_index)
210
 
211
- def get_recommendations(username, model, data, G, user_nodes, post_nodes, node2idx, top_k=10):
212
- """
213
- For a given username, compute the user's embedding and rank candidate posts (that the user hasn't interacted with).
214
- """
215
- if username not in user_nodes:
216
- return []
217
- user_idx = node2idx[username]
218
-
219
- # Find posts the user already interacted with (edges from username)
220
- user_interacted = set()
221
- for _, v in G.out_edges(username):
222
- if G.nodes[v].get('type') == 'post':
223
- user_interacted.add(v)
224
-
225
- with torch.no_grad():
226
- embeddings = model(data.x, data.edge_index)
227
- user_embed = embeddings[user_idx]
228
-
229
- candidate_scores = []
230
- for post in post_nodes:
231
- if post in user_interacted:
232
- continue
233
- post_idx = node2idx[post]
234
- score = torch.dot(user_embed, embeddings[post_idx]).item()
235
- candidate_scores.append((post, score))
236
- candidate_scores = sorted(candidate_scores, key=lambda x: x[1], reverse=True)
237
- top_posts = [post for post, score in candidate_scores[:top_k]]
238
- return top_posts
239
-
240
- # Endpoints
241
  @app.post("/rebuild")
242
  async def rebuild_handler():
243
  rebuild_model()
244
  return {"status": "success", "message": "Model and data rebuilt successfully"}
245
 
246
  @app.get("/recommend/feed")
247
- async def get_recommendations_handler(username: str = Query(...)):
248
  if trained_model is None:
249
  raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
250
- recs = get_recommendations(username, trained_model, pyg_data, G, user_nodes, post_nodes, node2idx)
251
  return {"status": "success", "recommendations": recs}
252
 
253
  @app.get("/")
254
  async def health_check():
255
  return {"status": "success", "message": "Recommendation service operational"}
256
 
257
- # Optionally, rebuild the model on startup
258
  rebuild_model()
 
18
 
19
  app = FastAPI()
20
 
 
21
  app.add_middleware(
22
  CORSMiddleware,
23
  allow_origins=["*"],
 
33
  if torch.cuda.is_available():
34
  torch.cuda.manual_seed_all(SEED)
35
 
 
36
  global G, features, user_nodes, post_nodes, node2idx, pyg_data, trained_model
37
+ G = None
38
+ features = None
39
+ user_nodes = None
40
+ post_nodes = None
41
+ node2idx = None
42
+ pyg_data = None
43
+ trained_model = None
44
+
 
45
  SUPABASE_URL = os.getenv('supabaseUrl')
46
  SUPABASE_KEY = os.getenv('supabaseAnonKey')
47
 
 
49
  return create_client(SUPABASE_URL, SUPABASE_KEY)
50
 
51
  def load_and_preprocess_data_for_posts():
 
 
 
 
 
 
 
 
52
  supabase = get_supabase_client()
53
+
54
+ profiles_response = supabase.table('profiles').select('id').execute()
 
55
  df_profiles = pd.DataFrame(profiles_response.data)
56
+
 
 
 
57
  posts_response = supabase.table('posts').select('id, author').execute()
58
  df_posts = pd.DataFrame(posts_response.data)
59
+
 
 
 
60
  likes_response = supabase.table('likes').select('user_id, post_id').execute()
61
  df_likes = pd.DataFrame(likes_response.data)
62
+
 
 
63
  bipartite = nx.DiGraph()
64
+
65
+ user_set = set(df_posts['author'].dropna().tolist()) | set(df_likes['user_id'].dropna().tolist())
 
 
66
  post_set = set(df_posts['id'].tolist())
67
+
 
68
  for user in user_set:
69
+ if user:
70
  bipartite.add_node(user, type='user')
71
+
72
  for post in post_set:
73
  bipartite.add_node(post, type='post')
74
+
 
75
  for _, row in df_posts.iterrows():
76
+ user = row['author']
77
  post = row['id']
78
  if user and post:
79
  bipartite.add_edge(user, post)
80
+
 
81
  for _, row in df_likes.iterrows():
82
+ user = row['user_id']
83
  post = row['post_id']
84
  if user and post:
85
  bipartite.add_edge(user, post)
86
+
87
  return bipartite
88
 
 
89
  class GraphRecommender(nn.Module):
90
  def __init__(self, input_dim, hidden_dim=128, output_dim=64):
91
  super().__init__()
92
  self.conv1 = SAGEConv(input_dim, hidden_dim)
93
  self.conv2 = SAGEConv(hidden_dim, output_dim)
94
  self.dropout = nn.Dropout(0.3)
95
+
96
  def forward(self, x, edge_index):
97
  x = F.relu(self.conv1(x, edge_index))
98
  x = self.dropout(x)
 
100
  return x
101
 
102
  def prepare_training_data(G, node2idx, user_nodes, post_nodes):
103
+ pos_edges = [(node2idx[u], node2idx[v]) for u, v in G.edges() if G.nodes[u]['type'] == 'user' and G.nodes[v]['type'] == 'post']
104
+ pos_edge_index = torch.tensor(pos_edges).T
105
+
 
 
 
 
 
 
 
 
 
106
  all_possible = [(node2idx[u], node2idx[p]) for u in user_nodes for p in post_nodes]
107
  pos_set = set(pos_edges)
108
  neg_candidates = [pair for pair in all_possible if pair not in pos_set]
 
109
  neg_sample_size = min(len(pos_edges), len(neg_candidates))
110
  neg_edges = random.sample(neg_candidates, neg_sample_size)
111
  neg_edge_index = torch.tensor(neg_edges).T
112
+
113
  return pos_edge_index, neg_edge_index
114
 
115
  def train_model(model, data, pos_edges, neg_edges, epochs=200):
 
123
 
124
  embeddings = model(data.x, data.edge_index)
125
 
 
126
  pos_scores = (embeddings[pos_edges[0]] * embeddings[pos_edges[1]]).sum(1)
127
  neg_scores = (embeddings[neg_edges[0]] * embeddings[neg_edges[1]]).sum(1)
128
 
 
146
  return model
147
 
148
  def rebuild_model():
 
 
 
 
149
  global G, features, user_nodes, post_nodes, node2idx, pyg_data, trained_model
150
  G = load_and_preprocess_data_for_posts()
151
+
 
 
 
152
  user_nodes = sorted(n for n, attr in G.nodes(data=True) if attr.get('type') == 'user')
153
  post_nodes = sorted(n for n, attr in G.nodes(data=True) if attr.get('type') == 'post')
154
  all_nodes = user_nodes + post_nodes
155
  node2idx = {node: i for i, node in enumerate(all_nodes)}
156
+
 
157
  features = torch.eye(len(all_nodes))
158
  pyg_data = from_networkx(G)
159
  pyg_data.x = features
160
+
161
  pos_edge_index, neg_edge_index = prepare_training_data(G, node2idx, user_nodes, post_nodes)
162
+
163
  input_dim = features.shape[1]
164
+ model = GraphRecommender(input_dim=input_dim)
165
  trained_model = train_model(model, pyg_data, pos_edge_index, neg_edge_index)
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  @app.post("/rebuild")
168
  async def rebuild_handler():
169
  rebuild_model()
170
  return {"status": "success", "message": "Model and data rebuilt successfully"}
171
 
172
  @app.get("/recommend/feed")
173
+ async def get_recommendations_handler(user_id: str = Query(...)):
174
  if trained_model is None:
175
  raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
176
+ recs = get_recommendations(user_id, trained_model, pyg_data, G, user_nodes, post_nodes, node2idx)
177
  return {"status": "success", "recommendations": recs}
178
 
179
  @app.get("/")
180
  async def health_check():
181
  return {"status": "success", "message": "Recommendation service operational"}
182
 
 
183
  # Runs at import time: loads the graph and trains the model once, so that
  # trained_model is already set when the request handlers are first invoked.
  rebuild_model()