andykr1k committed on
Commit
caa7929
·
1 Parent(s): fd29a2f

Changed to using user id

Browse files
Files changed (1) hide show
  1. app.py +25 -31
app.py CHANGED
@@ -36,14 +36,13 @@ if torch.cuda.is_available():
36
  torch.cuda.manual_seed_all(SEED)
37
 
38
  # Global variables
39
- global G, features, usernames, pyg_data, trained_model
40
  G = None
41
  features = None
42
- usernames = None
43
  pyg_data = None
44
  trained_model = None
45
 
46
- SUPABASE_ID = os.getenv('supabaseID')
47
  SUPABASE_URL = os.getenv('supabaseUrl')
48
  SUPABASE_KEY = os.getenv('supabaseAnonKey')
49
 
@@ -53,32 +52,27 @@ def get_supabase_client():
53
  def load_and_preprocess_data():
54
  supabase = get_supabase_client()
55
  followers_response = supabase.table('followers').select('*').execute()
56
- users_response = supabase.table('profiles').select('id, username').execute()
57
 
58
  followers = pd.DataFrame(followers_response.data)
59
  users = pd.DataFrame(users_response.data)
60
 
61
- merged = followers.merge(users[['id', 'username']],
62
- left_on='following', right_on='id', how='left')
63
- merged = merged.rename(columns={'username': 'follower_username'}).drop(columns=['id_y'])
64
 
65
- merged = merged.merge(users[['id', 'username']],
66
- left_on='id_x', right_on='id', how='left')
67
- merged = merged.rename(columns={'username': 'followed_username'})
68
-
69
- merged = merged[['follower_username', 'followed_username']].dropna()
70
- return merged[(merged['follower_username'] != '') & (merged['followed_username'] != '')]
71
 
72
  def create_graph_dataframe(merged_df):
73
- G = nx.from_pandas_edgelist(merged_df, source='follower_username', target='followed_username', create_using=nx.DiGraph())
74
- usernames = sorted(G.nodes())
75
- return G, torch.eye(len(usernames)), usernames
76
 
77
- def prepare_training_data(G, usernames):
78
- pos_edges = [(usernames.index(u), usernames.index(v)) for u, v in G.edges()]
79
  pos_edge_index = torch.tensor(pos_edges).T
80
 
81
- num_nodes = len(usernames)
82
  all_possible_edges = set(itertools.permutations(range(num_nodes), 2))
83
  existing_edges = set(zip(pos_edge_index[0].tolist(), pos_edge_index[1].tolist()))
84
  negative_edges = random.sample(list(all_possible_edges - existing_edges), len(pos_edges))
@@ -131,20 +125,20 @@ def train_model(model, data, pos_edges, neg_edges, epochs=200):
131
 
132
  return model
133
 
134
- def get_recommendations(username, model, data, G, usernames, top_k=10):
135
- if username not in usernames:
136
  return []
137
 
138
- user_idx = usernames.index(username)
139
- current_follows = set(G.successors(username))
140
 
141
- candidates = [u for u in usernames if u != username and u not in current_follows]
142
 
143
  with torch.no_grad():
144
  embeddings = model(data.x, data.edge_index)
145
  user_embed = embeddings[user_idx]
146
 
147
- candidate_indices = [usernames.index(u) for u in candidates]
148
  candidate_embeds = embeddings[candidate_indices]
149
 
150
  scores = torch.mm(user_embed.view(1, -1), candidate_embeds.T).squeeze()
@@ -153,14 +147,14 @@ def get_recommendations(username, model, data, G, usernames, top_k=10):
153
  return [candidates[i] for i in top_indices]
154
 
155
  def rebuild_model():
156
- global G, features, usernames, pyg_data, trained_model
157
  merged_df = load_and_preprocess_data()
158
- G, features, usernames = create_graph_dataframe(merged_df)
159
  pyg_data = from_networkx(G)
160
  pyg_data.x = features
161
 
162
- pos_edge_index, neg_edge_index = prepare_training_data(G, usernames)
163
- model = GraphRecommender(input_dim=len(usernames))
164
  trained_model = train_model(model, pyg_data, pos_edge_index, neg_edge_index)
165
 
166
  @app.post("/rebuild")
@@ -169,11 +163,11 @@ async def rebuild_handler():
169
  return {"status": "success", "message": "Model and data rebuilt successfully"}
170
 
171
  @app.get("/recommend/network")
172
- async def get_recommendations_handler(username: str = Query(...)):
173
  if not trained_model:
174
  raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
175
 
176
- recommendations = get_recommendations(username, trained_model, pyg_data, G, usernames)
177
  return {"status": "success", "recommendations": recommendations}
178
 
179
  @app.get("/")
 
36
  torch.cuda.manual_seed_all(SEED)
37
 
38
  # Global variables
39
+ global G, features, user_ids, pyg_data, trained_model
40
  G = None
41
  features = None
42
+ user_ids = None
43
  pyg_data = None
44
  trained_model = None
45
 
 
46
  SUPABASE_URL = os.getenv('supabaseUrl')
47
  SUPABASE_KEY = os.getenv('supabaseAnonKey')
48
 
 
52
  def load_and_preprocess_data():
53
  supabase = get_supabase_client()
54
  followers_response = supabase.table('followers').select('*').execute()
55
+ users_response = supabase.table('profiles').select('id').execute()
56
 
57
  followers = pd.DataFrame(followers_response.data)
58
  users = pd.DataFrame(users_response.data)
59
 
60
+ merged = followers.merge(users, left_on='following', right_on='id', how='left')
61
+ merged = merged.rename(columns={'id_x': 'follower_id', 'id_y': 'followed_id'})
 
62
 
63
+ merged = merged[['follower_id', 'followed_id']].dropna()
64
+ return merged[(merged['follower_id'] != '') & (merged['followed_id'] != '')]
 
 
 
 
65
 
66
  def create_graph_dataframe(merged_df):
67
+ G = nx.from_pandas_edgelist(merged_df, source='follower_id', target='followed_id', create_using=nx.DiGraph())
68
+ user_ids = sorted(G.nodes())
69
+ return G, torch.eye(len(user_ids)), user_ids
70
 
71
+ def prepare_training_data(G, user_ids):
72
+ pos_edges = [(user_ids.index(u), user_ids.index(v)) for u, v in G.edges()]
73
  pos_edge_index = torch.tensor(pos_edges).T
74
 
75
+ num_nodes = len(user_ids)
76
  all_possible_edges = set(itertools.permutations(range(num_nodes), 2))
77
  existing_edges = set(zip(pos_edge_index[0].tolist(), pos_edge_index[1].tolist()))
78
  negative_edges = random.sample(list(all_possible_edges - existing_edges), len(pos_edges))
 
125
 
126
  return model
127
 
128
+ def get_recommendations(user_id, model, data, G, user_ids, top_k=10):
129
+ if user_id not in user_ids:
130
  return []
131
 
132
+ user_idx = user_ids.index(user_id)
133
+ current_follows = set(G.successors(user_id))
134
 
135
+ candidates = [u for u in user_ids if u != user_id and u not in current_follows]
136
 
137
  with torch.no_grad():
138
  embeddings = model(data.x, data.edge_index)
139
  user_embed = embeddings[user_idx]
140
 
141
+ candidate_indices = [user_ids.index(u) for u in candidates]
142
  candidate_embeds = embeddings[candidate_indices]
143
 
144
  scores = torch.mm(user_embed.view(1, -1), candidate_embeds.T).squeeze()
 
147
  return [candidates[i] for i in top_indices]
148
 
149
  def rebuild_model():
150
+ global G, features, user_ids, pyg_data, trained_model
151
  merged_df = load_and_preprocess_data()
152
+ G, features, user_ids = create_graph_dataframe(merged_df)
153
  pyg_data = from_networkx(G)
154
  pyg_data.x = features
155
 
156
+ pos_edge_index, neg_edge_index = prepare_training_data(G, user_ids)
157
+ model = GraphRecommender(input_dim=len(user_ids))
158
  trained_model = train_model(model, pyg_data, pos_edge_index, neg_edge_index)
159
 
160
  @app.post("/rebuild")
 
163
  return {"status": "success", "message": "Model and data rebuilt successfully"}
164
 
165
  @app.get("/recommend/network")
166
+ async def get_recommendations_handler(user_id: int = Query(...)):
167
  if not trained_model:
168
  raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
169
 
170
+ recommendations = get_recommendations(user_id, trained_model, pyg_data, G, user_ids)
171
  return {"status": "success", "recommendations": recommendations}
172
 
173
  @app.get("/")