andykr1k committed on
Commit
d070610
·
1 Parent(s): 828fed3

added scheduler, logging and optimization updates

Browse files
Files changed (1) hide show
  1. app.py +122 -47
app.py CHANGED
@@ -13,13 +13,21 @@ from torch_geometric.nn import SAGEConv
13
  from supabase import create_client
14
  from fastapi import FastAPI, HTTPException, Query
15
  from fastapi.middleware.cors import CORSMiddleware
 
16
  from dotenv import load_dotenv
 
 
 
 
 
 
 
 
17
 
18
  load_dotenv()
19
 
20
  app = FastAPI()
21
 
22
- # Enable CORS
23
  app.add_middleware(
24
  CORSMiddleware,
25
  allow_origins=["*"],
@@ -37,11 +45,7 @@ if torch.cuda.is_available():
37
 
38
  # Global variables
39
  global G, features, user_ids, pyg_data, trained_model
40
- G = None
41
- features = None
42
- user_ids = None
43
- pyg_data = None
44
- trained_model = None
45
 
46
  SUPABASE_URL = os.getenv('supabaseUrl')
47
  SUPABASE_KEY = os.getenv('supabaseAnonKey')
@@ -51,34 +55,60 @@ def get_supabase_client():
51
 
52
  def load_and_preprocess_data():
53
  supabase = get_supabase_client()
54
- followers_response = supabase.table('followers').select('*').execute()
55
- users_response = supabase.table('profiles').select('id').execute()
56
-
57
- followers = pd.DataFrame(followers_response.data)
58
- users = pd.DataFrame(users_response.data)
59
-
60
- # 'id' is the followed user, 'following' is the follower
61
- merged = followers.merge(users, left_on='id', right_on='id', how='left')
62
- merged = merged.rename(columns={'following': 'follower_id', 'id': 'followed_id'})
63
- merged = merged[['follower_id', 'followed_id']].dropna()
64
- return merged[(merged['follower_id'] != '') & (merged['followed_id'] != '')]
65
-
66
- def create_graph_dataframe(merged_df):
67
- # Edge from follower to followed
68
- G = nx.from_pandas_edgelist(merged_df, source='follower_id', target='followed_id', create_using=nx.DiGraph())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  user_ids = sorted(G.nodes())
70
- return G, torch.eye(len(user_ids)), user_ids
 
 
 
 
 
 
 
 
71
 
72
  def prepare_training_data(G, user_ids):
73
  pos_edges = [(user_ids.index(u), user_ids.index(v)) for u, v in G.edges()]
74
- pos_edge_index = torch.tensor(pos_edges).T
75
 
76
  num_nodes = len(user_ids)
77
  all_possible_edges = set(itertools.permutations(range(num_nodes), 2))
78
  existing_edges = set(zip(pos_edge_index[0].tolist(), pos_edge_index[1].tolist()))
79
- negative_edges = random.sample(list(all_possible_edges - existing_edges), len(pos_edges))
 
80
 
81
- return pos_edge_index, torch.tensor(negative_edges).T
 
82
 
83
  class GraphRecommender(nn.Module):
84
  def __init__(self, input_dim, hidden_dim=128, output_dim=64):
@@ -93,11 +123,18 @@ class GraphRecommender(nn.Module):
93
  x = self.conv2(x, edge_index)
94
  return x
95
 
96
- def train_model(model, data, pos_edges, neg_edges, epochs=200):
 
 
 
 
 
 
97
  optimizer = optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-4)
98
  best_loss = float('inf')
99
  patience_counter = 0
100
 
 
101
  for epoch in range(epochs):
102
  model.train()
103
  optimizer.zero_grad()
@@ -121,42 +158,48 @@ def train_model(model, data, pos_edges, neg_edges, epochs=200):
121
  patience_counter = 0
122
  else:
123
  patience_counter += 1
124
- if patience_counter >= 20:
 
125
  break
126
 
127
- return model
 
128
 
129
  def get_recommendations(user_id, model, data, G, user_ids, top_k=10):
130
  if user_id not in user_ids:
131
  return []
132
 
133
  user_idx = user_ids.index(user_id)
134
- # Successors are the users this user follows
135
  current_follows = set(G.successors(user_id))
136
-
137
- # Exclude self and already-followed users
138
- candidates = [u for u in user_ids if u != user_id and u not in current_follows]
139
-
140
  with torch.no_grad():
141
  embeddings = model(data.x, data.edge_index)
142
- user_embed = embeddings[user_idx]
143
- candidate_indices = [user_ids.index(u) for u in candidates]
144
  candidate_embeds = embeddings[candidate_indices]
145
- scores = torch.mm(user_embed.view(1, -1), candidate_embeds.T).squeeze()
146
 
147
  top_indices = scores.argsort(descending=True)[:top_k]
148
- return [candidates[i] for i in top_indices]
 
 
149
 
150
  def rebuild_model():
151
  global G, features, user_ids, pyg_data, trained_model
152
- merged_df = load_and_preprocess_data()
153
- G, features, user_ids = create_graph_dataframe(merged_df)
154
- pyg_data = from_networkx(G)
155
- pyg_data.x = features
156
-
157
- pos_edge_index, neg_edge_index = prepare_training_data(G, user_ids)
158
- model = GraphRecommender(input_dim=len(user_ids))
159
- trained_model = train_model(model, pyg_data, pos_edge_index, neg_edge_index)
 
 
 
 
 
 
160
 
161
  @app.post("/rebuild")
162
  async def rebuild_handler():
@@ -169,10 +212,42 @@ async def get_recommendations_handler(user_id: str = Query(...)):
169
  raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
170
 
171
  recommendations = get_recommendations(user_id, trained_model, pyg_data, G, user_ids)
172
- return {"status": "success", "recommendations": recommendations}
 
 
 
 
 
 
 
 
 
 
173
 
174
  @app.get("/")
175
  async def health_check():
176
  return {"status": "success", "message": "Recommendation service operational"}
177
 
178
- rebuild_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  from supabase import create_client
14
  from fastapi import FastAPI, HTTPException, Query
15
  from fastapi.middleware.cors import CORSMiddleware
16
+ from fastapi.responses import StreamingResponse
17
  from dotenv import load_dotenv
18
+ import json
19
+ from apscheduler.schedulers.background import BackgroundScheduler
20
+ from apscheduler.triggers.cron import CronTrigger
21
+ import logging
22
+
23
+ # Configure logging
24
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
25
+ logger = logging.getLogger(__name__)
26
 
27
  load_dotenv()
28
 
29
  app = FastAPI()
30
 
 
31
  app.add_middleware(
32
  CORSMiddleware,
33
  allow_origins=["*"],
 
45
 
46
  # Global variables
47
  global G, features, user_ids, pyg_data, trained_model
48
+ G = features = user_ids = pyg_data = trained_model = None
 
 
 
 
49
 
50
  SUPABASE_URL = os.getenv('supabaseUrl')
51
  SUPABASE_KEY = os.getenv('supabaseAnonKey')
 
55
 
56
def load_and_preprocess_data():
    """Fetch follow relationships from Supabase as a list of edge records.

    Reads the ``id, following`` columns of the ``followers`` table and the
    ``id`` column of ``profiles``, page by page. In a ``followers`` row,
    ``id`` is the followed user and ``following`` is the follower.

    Returns:
        list[dict]: one ``{'follower_id': ..., 'followed_id': ...}`` record
        per ``followers`` row whose followed user exists in ``profiles`` and
        whose two ids are both present (neither ``None`` nor ``''``).
    """
    supabase = get_supabase_client()
    logger.info("Loading data from Supabase")

    def fetch_table(table, columns, chunk_size=1000):
        """Return every row of *table*, requesting *chunk_size* rows per call."""
        offset = 0
        all_data = []
        while True:
            response = (
                supabase.table(table)
                .select(columns)
                .range(offset, offset + chunk_size - 1)
                .execute()
            )
            data = response.data
            if not data:
                break
            all_data.extend(data)
            # Advance by what was actually returned: if the server caps the
            # page below chunk_size, stepping by chunk_size would skip rows.
            offset += len(data)
        return all_data

    followers = fetch_table('followers', 'id, following')
    users = fetch_table('profiles', 'id')

    user_set = {u['id'] for u in users}

    # Iterate the rows themselves. Keying a dict by the followed id would keep
    # only a single follower per followed user and drop every other edge.
    merged = []
    for row in followers:
        followed_id = row['id']
        follower_id = row['following']
        if followed_id in (None, '') or follower_id in (None, ''):
            continue
        if followed_id not in user_set:
            continue
        merged.append({'follower_id': follower_id, 'followed_id': followed_id})

    logger.info(f"Loaded {len(merged)} follower relationships")
    return merged
84
+
85
def create_graph_dataframe(merged_data):
    """Build the directed follow graph and its one-hot node features.

    Args:
        merged_data: iterable of dicts with ``'follower_id'`` and
            ``'followed_id'`` keys; each becomes an edge follower -> followed.

    Returns:
        tuple: ``(G, features, user_ids)`` where ``G`` is an ``nx.DiGraph``,
        ``user_ids`` is the sorted list of node ids, and ``features`` is a
        sparse ``len(user_ids) x len(user_ids)`` identity matrix whose row
        ``i`` belongs to ``user_ids[i]``.
    """
    edges = [(d['follower_id'], d['followed_id']) for d in merged_data]
    user_ids = sorted({node for edge in edges for node in edge})

    G = nx.DiGraph()
    # Insert nodes in sorted order BEFORE the edges so that G.nodes() iterates
    # in the same order as user_ids. Converters that number nodes by iteration
    # order (e.g. from_networkx) then agree with user_ids.index(...); adding
    # edges first would leave the graph in edge-insertion order instead.
    G.add_nodes_from(user_ids)
    G.add_edges_from(edges)

    # Use sparse identity matrix for features
    num_nodes = len(user_ids)
    features = torch.sparse_coo_tensor(
        torch.arange(num_nodes).repeat(2, 1),
        torch.ones(num_nodes),
        (num_nodes, num_nodes)
    )
    logger.info(f"Created graph with {len(user_ids)} nodes")
    return G, features, user_ids
99
 
100
def prepare_training_data(G, user_ids):
    """Build positive and sampled negative edge index tensors for training.

    Args:
        G: directed graph exposing ``edges()`` as ``(source, target)`` pairs
            whose endpoints all appear in ``user_ids``.
        user_ids: list of node ids; a node's position is its integer index.

    Returns:
        tuple: ``(pos_edge_index, neg_edge_index)``, both ``torch.long``
        tensors of shape ``(2, num_edges)``. Negatives are distinct
        non-self-loop pairs that are not edges of ``G``; their count equals
        the number of positives, or fewer when the graph is too dense to
        supply that many.
    """
    # One dict lookup per endpoint instead of an O(N) list scan.
    index_of = {uid: idx for idx, uid in enumerate(user_ids)}
    pos_edges = [(index_of[u], index_of[v]) for u, v in G.edges()]
    # reshape(-1, 2) keeps the result (2, 0) rather than (0,) when G has no edges.
    pos_edge_index = torch.tensor(pos_edges, dtype=torch.long).reshape(-1, 2).T

    num_nodes = len(user_ids)
    existing_edges = set(pos_edges)
    existing_non_loops = sum(1 for a, b in existing_edges if a != b)
    available = num_nodes * (num_nodes - 1) - existing_non_loops
    neg_sample_size = min(len(pos_edges), available)
    if neg_sample_size < len(pos_edges):
        logger.warning(
            f"Only {available} negative edges available for {len(pos_edges)} positive edges"
        )

    if neg_sample_size == 0:
        negative_edges = []
    elif 2 * neg_sample_size > available:
        # Dense graph: few free pairs remain, so enumerate them directly.
        # This only happens when the edge count is already on the order of N^2.
        free_pairs = [
            (a, b)
            for a in range(num_nodes)
            for b in range(num_nodes)
            if a != b and (a, b) not in existing_edges
        ]
        negative_edges = random.sample(free_pairs, neg_sample_size)
    else:
        # Sparse graph: rejection sampling avoids materialising all N^2 pairs.
        chosen = set()
        while len(chosen) < neg_sample_size:
            a = random.randrange(num_nodes)
            b = random.randrange(num_nodes)
            if a != b and (a, b) not in existing_edges:
                chosen.add((a, b))
        negative_edges = list(chosen)

    logger.info(f"Prepared {len(pos_edges)} positive and {len(negative_edges)} negative edges")
    neg_edge_index = torch.tensor(negative_edges, dtype=torch.long).reshape(-1, 2).T
    return pos_edge_index, neg_edge_index
112
 
113
  class GraphRecommender(nn.Module):
114
  def __init__(self, input_dim, hidden_dim=128, output_dim=64):
 
123
  x = self.conv2(x, edge_index)
124
  return x
125
 
126
+ def train_model(model, data, pos_edges, neg_edges, epochs=200, patience=20):
127
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
128
+ model = model.to(device)
129
+ data = data.to(device)
130
+ pos_edges = pos_edges.to(device)
131
+ neg_edges = neg_edges.to(device)
132
+
133
  optimizer = optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-4)
134
  best_loss = float('inf')
135
  patience_counter = 0
136
 
137
+ logger.info("Starting model training")
138
  for epoch in range(epochs):
139
  model.train()
140
  optimizer.zero_grad()
 
158
  patience_counter = 0
159
  else:
160
  patience_counter += 1
161
+ if patience_counter >= patience:
162
+ logger.info(f"Early stopping at epoch {epoch}")
163
  break
164
 
165
+ logger.info("Model training completed")
166
+ return model.to('cpu') # Move back to CPU for inference
167
 
168
def get_recommendations(user_id, model, data, G, user_ids, top_k=10):
    """Return up to ``top_k`` user ids recommended for ``user_id`` to follow.

    Candidates are all users other than ``user_id`` that ``user_id`` does not
    already follow (its successors in ``G``). They are ranked by the dot
    product between their embedding and the user's embedding.

    Args:
        user_id: id of the user to recommend for.
        model: trained module called as ``model(x, edge_index)``.
        data: graph data object providing ``x`` and ``edge_index``.
        G: directed follow graph (edges point follower -> followed).
        user_ids: list of node ids; position ``i`` matches embedding row ``i``.
        top_k: maximum number of recommendations.

    Returns:
        list: recommended user ids, best first; empty when the user is
        unknown or there are no candidates.
    """
    try:
        user_idx = user_ids.index(user_id)
    except ValueError:
        return []

    current_follows = set(G.successors(user_id))
    candidate_indices = [
        idx for idx, uid in enumerate(user_ids)
        if uid != user_id and uid not in current_follows
    ]

    recommendations = []
    if candidate_indices and top_k > 0:
        # NOTE(review): training may leave `data` on a different device than
        # the returned model; align inputs with the model to be safe.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else data.x.device
        x = data.x.to(device)
        edge_index = data.edge_index.to(device)

        model.eval()  # disable any train-only layers during inference
        with torch.no_grad():
            embeddings = model(x, edge_index)
            # (C, D) @ (D,) -> (C,): stays 1-D even with a single candidate,
            # unlike a bare .squeeze(), which collapses it to a 0-d tensor.
            scores = embeddings[candidate_indices] @ embeddings[user_idx]
            k = min(top_k, len(candidate_indices))
            top_positions = torch.topk(scores, k).indices.tolist()
        recommendations = [user_ids[candidate_indices[pos]] for pos in top_positions]

    logger.info(f"Generated {len(recommendations)} recommendations for user {user_id}")
    return recommendations
186
 
187
def rebuild_model():
    """Reload the follow graph, retrain the recommender and publish the result.

    Runs the full pipeline (load data, build graph, convert to PyG data,
    sample training edges, train) into local variables. The module globals
    ``G``, ``features``, ``user_ids``, ``pyg_data`` and ``trained_model`` are
    replaced together only after every step has succeeded, so a failed rebuild
    leaves the previously served state untouched and consistent.

    Raises:
        Exception: whatever the pipeline raised, after it has been logged.
    """
    global G, features, user_ids, pyg_data, trained_model
    # This runs at startup and on the daily schedule, so the message must not
    # claim a specific time of day.
    logger.info("Starting model rebuild")
    try:
        merged_data = load_and_preprocess_data()
        new_graph, new_features, new_user_ids = create_graph_dataframe(merged_data)
        new_pyg_data = from_networkx(new_graph)
        new_pyg_data.x = new_features

        pos_edge_index, neg_edge_index = prepare_training_data(new_graph, new_user_ids)
        model = GraphRecommender(input_dim=len(new_user_ids))
        new_model = train_model(model, new_pyg_data, pos_edge_index, neg_edge_index)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception(f"Error during model rebuild: {str(e)}")
        raise

    G, features, user_ids, pyg_data, trained_model = (
        new_graph, new_features, new_user_ids, new_pyg_data, new_model
    )
    logger.info("Model rebuild completed successfully")
203
 
204
  @app.post("/rebuild")
205
  async def rebuild_handler():
 
212
  raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
213
 
214
  recommendations = get_recommendations(user_id, trained_model, pyg_data, G, user_ids)
215
+
216
+ # Stream the response
217
+ def generate():
218
+ yield '{"status": "success", "recommendations": ['
219
+ for i, rec in enumerate(recommendations):
220
+ yield json.dumps(rec)
221
+ if i < len(recommendations) - 1:
222
+ yield ','
223
+ yield ']}'
224
+
225
+ return StreamingResponse(generate(), media_type="application/json")
226
 
227
  @app.get("/")
228
  async def health_check():
229
  return {"status": "success", "message": "Recommendation service operational"}
230
 
231
# Background scheduler with a single cron job that calls rebuild_model
# every day at 03:30. The job is registered here; the scheduler itself is
# started and stopped by the application's startup/shutdown hooks.
scheduler = BackgroundScheduler()
scheduler.add_job(
    rebuild_model,
    CronTrigger(hour=3, minute=30),
    id='daily_model_rebuild',
    replace_existing=True,
)
239
+
240
@app.on_event("startup")
async def startup_event():
    """Build the model once, then start the daily rebuild scheduler."""
    # Initial build first, so a model exists before any request is served;
    # the scheduler is only started once that build has returned.
    rebuild_model()
    scheduler.start()
    logger.info("Scheduler started, model will rebuild daily at 3:30 AM")
245
+
246
@app.on_event("shutdown")
async def shutdown_event():
    """Stop the background scheduler when the application shuts down."""
    scheduler.shutdown(wait=True)
    logger.info("Scheduler shut down")
250
+
251
if __name__ == "__main__":
    # Script entry point: serve the FastAPI app on all interfaces, port 8000.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)