File size: 10,951 Bytes
fd29a2f
 
 
 
 
 
 
 
 
 
 
 
 
 
d070610
fd29a2f
d070610
 
5189742
d070610
 
4ff1bb6
d070610
 
 
 
fd29a2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e96e6d
 
 
 
 
fd29a2f
 
 
 
 
6e96e6d
fd29a2f
 
 
6e96e6d
fd29a2f
d070610
 
 
6e96e6d
d070610
 
 
 
 
 
 
 
 
 
 
 
 
 
6e96e6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d070610
6e96e6d
 
 
 
d070610
 
 
 
 
6e96e6d
 
d070610
 
 
caa7929
d070610
6e96e6d
739b5c0
d070610
 
fd29a2f
caa7929
6e96e6d
caa7929
6e96e6d
fd29a2f
caa7929
fd29a2f
 
d070610
 
fd29a2f
d070610
6e96e6d
fd29a2f
 
6e96e6d
fd29a2f
 
 
 
 
 
 
 
 
 
 
 
d070610
6e96e6d
d070610
 
 
 
 
 
fd29a2f
 
 
 
d070610
fd29a2f
 
 
 
 
 
6e96e6d
 
fd29a2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d070610
 
fd29a2f
 
d070610
739b5c0
fd29a2f
caa7929
6e96e6d
caa7929
6e96e6d
fd29a2f
 
caa7929
6e96e6d
d070610
 
6e96e6d
 
 
 
fd29a2f
 
d070610
fd29a2f
d070610
fd29a2f
 
d070610
 
 
fd29a2f
 
6e96e6d
caa7929
739b5c0
d070610
 
 
 
 
 
 
 
 
 
 
 
 
fd29a2f
 
 
6e96e6d
fd29a2f
 
 
 
c865291
6e96e6d
fd29a2f
 
6e96e6d
 
fd29a2f
caa7929
d070610
 
 
 
 
 
 
 
 
 
fd29a2f
 
 
6e96e6d
fd29a2f
 
5189742
 
 
 
 
 
 
 
 
 
 
 
739b5c0
4ff1bb6
d070610
 
5189742
d070610
 
 
5189742
 
 
 
 
 
d070610
 
 
6e96e6d
 
d070610
739b5c0
d070610
 
 
6e96e6d
d070610
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import os
import random
import itertools
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.utils import from_networkx
from torch_geometric.nn import SAGEConv
from supabase import create_client
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from dotenv import load_dotenv
import json
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.interval import IntervalTrigger
from apscheduler.triggers.cron import CronTrigger
import logging
import pytz

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

load_dotenv()

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Global variables
G = None
features = None
user_ids = None
pyg_data = None
trained_model = None

SUPABASE_URL = os.getenv('supabaseUrl')
SUPABASE_KEY = os.getenv('supabaseAnonKey')

def get_supabase_client():
    """Initialize and return a Supabase client."""
    return create_client(SUPABASE_URL, SUPABASE_KEY)

def load_and_preprocess_data():
    """Load and preprocess follower data from Supabase."""
    supabase = get_supabase_client()
    logger.info("Loading data from Supabase")

    def fetch_table(table, columns, chunk_size=1000):
        """Fetch data from a Supabase table in chunks."""
        offset = 0
        all_data = []
        while True:
            response = supabase.table(table).select(columns).range(offset, offset + chunk_size - 1).execute()
            data = response.data
            if not data:
                break
            all_data.extend(data)
            offset += chunk_size
        return all_data

    followers = fetch_table('followers', 'id, following')
    users = fetch_table('profiles', 'id')

    # Build follower_dict: id (followed) -> list of following (followers)
    follower_dict = {}
    for f in followers:
        followed_id = f['id']  # The user being followed
        follower_id = f['following']  # The user following the id
        if not follower_id or not followed_id:  # Skip invalid entries
            logger.warning(f"Skipping invalid entry: follower_id={follower_id}, followed_id={followed_id}")
            continue
        if followed_id in follower_dict:
            follower_dict[followed_id].append(follower_id)
        else:
            follower_dict[followed_id] = [follower_id]

    user_set = set(u['id'] for u in users if u['id'])  # Valid user IDs
    # Create edge list: follower (following) -> followed (id)
    merged = [
        {'follower_id': follower, 'followed_id': fid}
        for fid in follower_dict
        for follower in follower_dict[fid]
        if fid in user_set and follower in user_set
    ]
    logger.info(f"Loaded {len(merged)} follower relationships")
    return merged

def create_graph_dataframe(merged_data):
    """Create a directed graph and feature matrix from merged data."""
    global G, features, user_ids
    G = nx.DiGraph()
    edges = [(d['follower_id'], d['followed_id']) for d in merged_data]
    G.add_edges_from(edges)
    user_ids = sorted(G.nodes())
    
    # Use identity matrix as node features
    features = torch.eye(len(user_ids))
    logger.info(f"Created graph with {len(user_ids)} nodes")
    return G, features, user_ids

def prepare_training_data(G, user_ids):
    """Prepare positive and negative edge indices for training."""
    pos_edges = [(user_ids.index(u), user_ids.index(v)) for u, v in G.edges()]
    pos_edge_index = torch.tensor(pos_edges, dtype=torch.long).t()

    num_nodes = len(user_ids)
    all_possible_edges = set(itertools.permutations(range(num_nodes), 2))
    existing_edges = set(zip(pos_edge_index[0].tolist(), pos_edge_index[1].tolist()))
    neg_sample_size = len(pos_edges)
    negative_edges = random.sample(list(all_possible_edges - existing_edges), neg_sample_size)

    logger.info(f"Prepared {len(pos_edges)} positive and {len(negative_edges)} negative edges")
    return pos_edge_index, torch.tensor(negative_edges, dtype=torch.long).t()

class GraphRecommender(nn.Module):
    """GraphSAGE-based recommendation model."""
    def __init__(self, input_dim, hidden_dim=128, output_dim=64):
        super().__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        return x

def train_model(model, data, pos_edges, neg_edges, epochs=200, patience=20):
    """Train the GraphRecommender model."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    data = data.to(device)
    pos_edges = pos_edges.to(device)
    neg_edges = neg_edges.to(device)
    
    optimizer = optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-4)
    best_loss = float('inf')
    patience_counter = 0

    logger.info("Starting model training")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        embeddings = model(data.x, data.edge_index)

        pos_scores = (embeddings[pos_edges[0]] * embeddings[pos_edges[1]]).sum(dim=1)
        neg_scores = (embeddings[neg_edges[0]] * embeddings[neg_edges[1]]).sum(dim=1)

        pos_loss = F.binary_cross_entropy_with_logits(pos_scores, torch.ones_like(pos_scores))
        neg_loss = F.binary_cross_entropy_with_logits(neg_scores, torch.zeros_like(neg_scores))
        reg_loss = torch.norm(embeddings, p=2)

        total_loss = pos_loss + neg_loss + 0.001 * reg_loss

        total_loss.backward()
        optimizer.step()

        if total_loss < best_loss:
            best_loss = total_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(f"Early stopping at epoch {epoch}")
                break

    logger.info("Model training completed")
    return model.to('cpu')

def get_recommendations(user_id, model, data, G, user_ids, top_k=10):
    """Generate top-k user recommendations excluding current follows."""
    if user_id not in user_ids:
        logger.warning(f"User {user_id} not found in graph")
        return []

    user_idx = user_ids.index(user_id)
    current_follows = set(G.successors(user_id))  # Users this user follows
    candidate_indices = [i for i, u in enumerate(user_ids) if u != user_id and u not in current_follows]

    if not candidate_indices:
        logger.info(f"No new recommendations available for user {user_id}")
        return []

    with torch.no_grad():
        embeddings = model(data.x, data.edge_index)
        user_embed = embeddings[user_idx].unsqueeze(0)
        candidate_embeds = embeddings[candidate_indices]
        scores = torch.matmul(user_embed, candidate_embeds.T).squeeze()

    top_indices = scores.argsort(descending=True)[:top_k]
    recommendations = [user_ids[candidate_indices[i]] for i in top_indices]
    logger.info(f"Generated {len(recommendations)} recommendations for user {user_id}")
    return recommendations

def rebuild_model():
    """Rebuild the graph and retrain the model."""
    global G, features, user_ids, pyg_data, trained_model
    logger.info("Starting model rebuild at 3:30 AM Pacific Time")
    try:
        merged_data = load_and_preprocess_data()
        G, features, user_ids = create_graph_dataframe(merged_data)
        pyg_data = from_networkx(G)
        pyg_data.x = features

        pos_edge_index, neg_edge_index = prepare_training_data(G, user_ids)
        model = GraphRecommender(input_dim=len(user_ids))
        trained_model = train_model(model, pyg_data, pos_edge_index, neg_edge_index)
        logger.info("Model rebuild completed successfully")
    except Exception as e:
        logger.error(f"Error during model rebuild: {str(e)}")
        raise

@app.post("/rebuild")
async def rebuild_handler():
    """API endpoint to manually trigger model rebuild."""
    rebuild_model()
    return {"status": "success", "message": "Model and data rebuilt successfully"}

@app.get("/recommend/network")
async def get_recommendations_handler(user_id: str = Query(...)):
    """API endpoint to get recommendations for a user."""
    if not trained_model:
        raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
    if not user_id.strip():
        raise HTTPException(status_code=400, detail="Invalid user_id")

    recommendations = get_recommendations(user_id, trained_model, pyg_data, G, user_ids)
    
    def generate():
        yield '{"status": "success", "recommendations": ['
        for i, rec in enumerate(recommendations):
            yield json.dumps(rec)
            if i < len(recommendations) - 1:
                yield ','
        yield ']}'

    return StreamingResponse(generate(), media_type="application/json")

@app.get("/")
async def health_check():
    """API endpoint for health check."""
    return {"status": "success", "message": "Recommendation service operational"}

def ping_servers():
    notification_server_url = "https://andykrik-notificationservice.hf.space/"
    film_recommender_server_url = "https://andykrik-filmrecommender.hf.space/"
    try:
        import requests
        requests.get(notification_server_url)
        requests.get(film_recommender_server_url)
        logger.info("Pinged notification and film recommender servers successfully")
    except requests.RequestException as e:
        logger.error(f"Error pinging servers: {str(e)}")
        raise

# Scheduler setup with Pacific Time Zone
scheduler = BackgroundScheduler(timezone="America/Los_Angeles")
scheduler.add_job(
    rebuild_model,
    trigger=CronTrigger(hour=3, minute=30),
    id='daily_model_rebuild',
    replace_existing=True
)
scheduler.add_job(
    ping_servers,
    trigger=IntervalTrigger(hours=1),
    id='hourly_model_ping',
    replace_existing=True
)

@app.on_event("startup")
async def startup_event():
    """Startup event to initialize model and scheduler."""
    rebuild_model()
    scheduler.start()
    logger.info("Scheduler started, model will rebuild daily at 3:30 AM Pacific Time")

@app.on_event("shutdown")
async def shutdown_event():
    """Shutdown event to stop scheduler."""
    scheduler.shutdown()
    logger.info("Scheduler shut down")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)