Spaces:
Running
Running
andykr1k commited on
Commit ·
caa7929
1
Parent(s): fd29a2f
Changed to using user id
Browse files
app.py
CHANGED
|
@@ -36,14 +36,13 @@ if torch.cuda.is_available():
|
|
| 36 |
torch.cuda.manual_seed_all(SEED)
|
| 37 |
|
| 38 |
# Global variables
|
| 39 |
-
global G, features,
|
| 40 |
G = None
|
| 41 |
features = None
|
| 42 |
-
|
| 43 |
pyg_data = None
|
| 44 |
trained_model = None
|
| 45 |
|
| 46 |
-
SUPABASE_ID = os.getenv('supabaseID')
|
| 47 |
SUPABASE_URL = os.getenv('supabaseUrl')
|
| 48 |
SUPABASE_KEY = os.getenv('supabaseAnonKey')
|
| 49 |
|
|
@@ -53,32 +52,27 @@ def get_supabase_client():
|
|
| 53 |
def load_and_preprocess_data():
|
| 54 |
supabase = get_supabase_client()
|
| 55 |
followers_response = supabase.table('followers').select('*').execute()
|
| 56 |
-
users_response = supabase.table('profiles').select('id
|
| 57 |
|
| 58 |
followers = pd.DataFrame(followers_response.data)
|
| 59 |
users = pd.DataFrame(users_response.data)
|
| 60 |
|
| 61 |
-
merged = followers.merge(users
|
| 62 |
-
|
| 63 |
-
merged = merged.rename(columns={'username': 'follower_username'}).drop(columns=['id_y'])
|
| 64 |
|
| 65 |
-
merged = merged
|
| 66 |
-
|
| 67 |
-
merged = merged.rename(columns={'username': 'followed_username'})
|
| 68 |
-
|
| 69 |
-
merged = merged[['follower_username', 'followed_username']].dropna()
|
| 70 |
-
return merged[(merged['follower_username'] != '') & (merged['followed_username'] != '')]
|
| 71 |
|
| 72 |
def create_graph_dataframe(merged_df):
|
| 73 |
-
G = nx.from_pandas_edgelist(merged_df, source='
|
| 74 |
-
|
| 75 |
-
return G, torch.eye(len(
|
| 76 |
|
| 77 |
-
def prepare_training_data(G,
|
| 78 |
-
pos_edges = [(
|
| 79 |
pos_edge_index = torch.tensor(pos_edges).T
|
| 80 |
|
| 81 |
-
num_nodes = len(
|
| 82 |
all_possible_edges = set(itertools.permutations(range(num_nodes), 2))
|
| 83 |
existing_edges = set(zip(pos_edge_index[0].tolist(), pos_edge_index[1].tolist()))
|
| 84 |
negative_edges = random.sample(list(all_possible_edges - existing_edges), len(pos_edges))
|
|
@@ -131,20 +125,20 @@ def train_model(model, data, pos_edges, neg_edges, epochs=200):
|
|
| 131 |
|
| 132 |
return model
|
| 133 |
|
| 134 |
-
def get_recommendations(
|
| 135 |
-
if
|
| 136 |
return []
|
| 137 |
|
| 138 |
-
user_idx =
|
| 139 |
-
current_follows = set(G.successors(
|
| 140 |
|
| 141 |
-
candidates = [u for u in
|
| 142 |
|
| 143 |
with torch.no_grad():
|
| 144 |
embeddings = model(data.x, data.edge_index)
|
| 145 |
user_embed = embeddings[user_idx]
|
| 146 |
|
| 147 |
-
candidate_indices = [
|
| 148 |
candidate_embeds = embeddings[candidate_indices]
|
| 149 |
|
| 150 |
scores = torch.mm(user_embed.view(1, -1), candidate_embeds.T).squeeze()
|
|
@@ -153,14 +147,14 @@ def get_recommendations(username, model, data, G, usernames, top_k=10):
|
|
| 153 |
return [candidates[i] for i in top_indices]
|
| 154 |
|
| 155 |
def rebuild_model():
|
| 156 |
-
global G, features,
|
| 157 |
merged_df = load_and_preprocess_data()
|
| 158 |
-
G, features,
|
| 159 |
pyg_data = from_networkx(G)
|
| 160 |
pyg_data.x = features
|
| 161 |
|
| 162 |
-
pos_edge_index, neg_edge_index = prepare_training_data(G,
|
| 163 |
-
model = GraphRecommender(input_dim=len(
|
| 164 |
trained_model = train_model(model, pyg_data, pos_edge_index, neg_edge_index)
|
| 165 |
|
| 166 |
@app.post("/rebuild")
|
|
@@ -169,11 +163,11 @@ async def rebuild_handler():
|
|
| 169 |
return {"status": "success", "message": "Model and data rebuilt successfully"}
|
| 170 |
|
| 171 |
@app.get("/recommend/network")
|
| 172 |
-
async def get_recommendations_handler(
|
| 173 |
if not trained_model:
|
| 174 |
raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
|
| 175 |
|
| 176 |
-
recommendations = get_recommendations(
|
| 177 |
return {"status": "success", "recommendations": recommendations}
|
| 178 |
|
| 179 |
@app.get("/")
|
|
|
|
| 36 |
torch.cuda.manual_seed_all(SEED)
|
| 37 |
|
| 38 |
# Global variables
|
| 39 |
+
global G, features, user_ids, pyg_data, trained_model
|
| 40 |
G = None
|
| 41 |
features = None
|
| 42 |
+
user_ids = None
|
| 43 |
pyg_data = None
|
| 44 |
trained_model = None
|
| 45 |
|
|
|
|
| 46 |
SUPABASE_URL = os.getenv('supabaseUrl')
|
| 47 |
SUPABASE_KEY = os.getenv('supabaseAnonKey')
|
| 48 |
|
|
|
|
| 52 |
def load_and_preprocess_data():
|
| 53 |
supabase = get_supabase_client()
|
| 54 |
followers_response = supabase.table('followers').select('*').execute()
|
| 55 |
+
users_response = supabase.table('profiles').select('id').execute()
|
| 56 |
|
| 57 |
followers = pd.DataFrame(followers_response.data)
|
| 58 |
users = pd.DataFrame(users_response.data)
|
| 59 |
|
| 60 |
+
merged = followers.merge(users, left_on='following', right_on='id', how='left')
|
| 61 |
+
merged = merged.rename(columns={'id_x': 'follower_id', 'id_y': 'followed_id'})
|
|
|
|
| 62 |
|
| 63 |
+
merged = merged[['follower_id', 'followed_id']].dropna()
|
| 64 |
+
return merged[(merged['follower_id'] != '') & (merged['followed_id'] != '')]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
def create_graph_dataframe(merged_df):
|
| 67 |
+
G = nx.from_pandas_edgelist(merged_df, source='follower_id', target='followed_id', create_using=nx.DiGraph())
|
| 68 |
+
user_ids = sorted(G.nodes())
|
| 69 |
+
return G, torch.eye(len(user_ids)), user_ids
|
| 70 |
|
| 71 |
+
def prepare_training_data(G, user_ids):
|
| 72 |
+
pos_edges = [(user_ids.index(u), user_ids.index(v)) for u, v in G.edges()]
|
| 73 |
pos_edge_index = torch.tensor(pos_edges).T
|
| 74 |
|
| 75 |
+
num_nodes = len(user_ids)
|
| 76 |
all_possible_edges = set(itertools.permutations(range(num_nodes), 2))
|
| 77 |
existing_edges = set(zip(pos_edge_index[0].tolist(), pos_edge_index[1].tolist()))
|
| 78 |
negative_edges = random.sample(list(all_possible_edges - existing_edges), len(pos_edges))
|
|
|
|
| 125 |
|
| 126 |
return model
|
| 127 |
|
| 128 |
+
def get_recommendations(user_id, model, data, G, user_ids, top_k=10):
|
| 129 |
+
if user_id not in user_ids:
|
| 130 |
return []
|
| 131 |
|
| 132 |
+
user_idx = user_ids.index(user_id)
|
| 133 |
+
current_follows = set(G.successors(user_id))
|
| 134 |
|
| 135 |
+
candidates = [u for u in user_ids if u != user_id and u not in current_follows]
|
| 136 |
|
| 137 |
with torch.no_grad():
|
| 138 |
embeddings = model(data.x, data.edge_index)
|
| 139 |
user_embed = embeddings[user_idx]
|
| 140 |
|
| 141 |
+
candidate_indices = [user_ids.index(u) for u in candidates]
|
| 142 |
candidate_embeds = embeddings[candidate_indices]
|
| 143 |
|
| 144 |
scores = torch.mm(user_embed.view(1, -1), candidate_embeds.T).squeeze()
|
|
|
|
| 147 |
return [candidates[i] for i in top_indices]
|
| 148 |
|
| 149 |
def rebuild_model():
|
| 150 |
+
global G, features, user_ids, pyg_data, trained_model
|
| 151 |
merged_df = load_and_preprocess_data()
|
| 152 |
+
G, features, user_ids = create_graph_dataframe(merged_df)
|
| 153 |
pyg_data = from_networkx(G)
|
| 154 |
pyg_data.x = features
|
| 155 |
|
| 156 |
+
pos_edge_index, neg_edge_index = prepare_training_data(G, user_ids)
|
| 157 |
+
model = GraphRecommender(input_dim=len(user_ids))
|
| 158 |
trained_model = train_model(model, pyg_data, pos_edge_index, neg_edge_index)
|
| 159 |
|
| 160 |
@app.post("/rebuild")
|
|
|
|
| 163 |
return {"status": "success", "message": "Model and data rebuilt successfully"}
|
| 164 |
|
| 165 |
@app.get("/recommend/network")
|
| 166 |
+
async def get_recommendations_handler(user_id: int = Query(...)):
|
| 167 |
if not trained_model:
|
| 168 |
raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
|
| 169 |
|
| 170 |
+
recommendations = get_recommendations(user_id, trained_model, pyg_data, G, user_ids)
|
| 171 |
return {"status": "success", "recommendations": recommendations}
|
| 172 |
|
| 173 |
@app.get("/")
|