"""GraphSAGE-based follower recommendation service.

Loads follower relationships from Supabase, builds a directed follow graph,
trains a GraphSAGE link-prediction model, and serves top-k user
recommendations over a FastAPI HTTP API. The model is rebuilt daily at
3:30 AM Pacific Time by a background scheduler.
"""

import itertools
import json
import logging
import os
import random

import networkx as nx
import numpy as np
import pytz
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from supabase import create_client
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import from_networkx

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

load_dotenv()

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Seed every RNG so model rebuilds are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Global model/graph state, populated by rebuild_model().
G = None              # nx.DiGraph; edge u -> v means u follows v
features = None       # identity feature matrix; row i corresponds to user_ids[i]
user_ids = None       # sorted node ids; defines the canonical node index order
pyg_data = None       # torch_geometric Data aligned with the user_ids ordering
trained_model = None  # trained GraphRecommender, None until the first rebuild

SUPABASE_URL = os.getenv('supabaseUrl')
SUPABASE_KEY = os.getenv('supabaseAnonKey')


def get_supabase_client():
    """Initialize and return a Supabase client.

    Raises:
        RuntimeError: if the Supabase credentials are not configured, so the
            failure is explicit instead of a confusing downstream error.
    """
    if not SUPABASE_URL or not SUPABASE_KEY:
        raise RuntimeError("supabaseUrl / supabaseAnonKey environment variables are not set")
    return create_client(SUPABASE_URL, SUPABASE_KEY)


def load_and_preprocess_data():
    """Load and preprocess follower data from Supabase.

    Returns:
        list[dict]: records of the form
            {'follower_id': <user who follows>, 'followed_id': <user followed>},
        restricted to pairs where both ids exist in the profiles table.
    """
    supabase = get_supabase_client()
    logger.info("Loading data from Supabase")

    def fetch_table(table, columns, chunk_size=1000):
        """Fetch all rows of selected columns from a Supabase table in chunks."""
        offset = 0
        all_data = []
        while True:
            response = supabase.table(table).select(columns).range(offset, offset + chunk_size - 1).execute()
            data = response.data
            if not data:
                break
            all_data.extend(data)
            offset += chunk_size
        return all_data

    followers = fetch_table('followers', 'id, following')
    users = fetch_table('profiles', 'id')

    # Build follower_dict: id (followed) -> list of following (followers)
    follower_dict = {}
    for f in followers:
        followed_id = f['id']          # The user being followed
        follower_id = f['following']   # The user following the id
        if not follower_id or not followed_id:
            # Skip invalid entries
            logger.warning(f"Skipping invalid entry: follower_id={follower_id}, followed_id={followed_id}")
            continue
        follower_dict.setdefault(followed_id, []).append(follower_id)

    user_set = set(u['id'] for u in users if u['id'])  # Valid user IDs

    # Create edge list: follower (following) -> followed (id)
    merged = [
        {'follower_id': follower, 'followed_id': fid}
        for fid in follower_dict
        for follower in follower_dict[fid]
        if fid in user_set and follower in user_set
    ]
    logger.info(f"Loaded {len(merged)} follower relationships")
    return merged


def create_graph_dataframe(merged_data):
    """Create a directed graph and identity feature matrix from merged data.

    Sets the module-level G, features, and user_ids. Row i of `features`
    corresponds to `user_ids[i]` (sorted node order) — every other component
    must use this same ordering.
    """
    global G, features, user_ids
    G = nx.DiGraph()
    edges = [(d['follower_id'], d['followed_id']) for d in merged_data]
    G.add_edges_from(edges)
    user_ids = sorted(G.nodes())
    # Use identity matrix as node features (one-hot per user).
    features = torch.eye(len(user_ids))
    logger.info(f"Created graph with {len(user_ids)} nodes")
    return G, features, user_ids


def prepare_training_data(G, user_ids):
    """Prepare positive and negative edge index tensors for training.

    Node indices follow the sorted order of `user_ids` so they line up with
    the identity feature matrix and with recommendation lookups.

    Returns:
        (pos_edge_index, neg_edge_index): two [2, E] long tensors.
    """
    # O(1) id -> index lookups instead of O(V) list.index per edge.
    index_of = {u: i for i, u in enumerate(user_ids)}
    pos_edges = [(index_of[u], index_of[v]) for u, v in G.edges()]
    pos_edge_index = torch.tensor(pos_edges, dtype=torch.long).t()

    num_nodes = len(user_ids)
    existing_edges = set(pos_edges)
    total_pairs = num_nodes * (num_nodes - 1)
    # Guard: never request more negatives than non-edges actually exist.
    neg_sample_size = min(len(pos_edges), total_pairs - len(existing_edges))

    if len(existing_edges) < 0.5 * total_pairs:
        # Sparse graph (the common case): rejection-sample negatives instead of
        # materializing all O(n^2) node pairs, which does not scale past a few
        # thousand users.
        negatives = set()
        while len(negatives) < neg_sample_size:
            u = random.randrange(num_nodes)
            v = random.randrange(num_nodes)
            if u != v and (u, v) not in existing_edges and (u, v) not in negatives:
                negatives.add((u, v))
        negative_edges = list(negatives)
    else:
        # Dense graph: rejection sampling would thrash, so enumerate non-edges.
        all_possible_edges = set(itertools.permutations(range(num_nodes), 2))
        negative_edges = random.sample(list(all_possible_edges - existing_edges), neg_sample_size)

    logger.info(f"Prepared {len(pos_edges)} positive and {len(negative_edges)} negative edges")
    return pos_edge_index, torch.tensor(negative_edges, dtype=torch.long).t()


class GraphRecommender(nn.Module):
    """GraphSAGE-based recommendation model.

    Two SAGEConv layers with ReLU + dropout in between; produces an
    `output_dim`-dimensional embedding per node. Link scores are dot products
    of node embeddings.
    """

    def __init__(self, input_dim, hidden_dim=128, output_dim=64):
        super().__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        return x


def train_model(model, data, pos_edges, neg_edges, epochs=200, patience=20):
    """Train the GraphRecommender model with BCE link-prediction loss.

    Uses dot-product scores on positive vs. sampled negative edges, an L2
    embedding penalty, and early stopping on the training loss.

    Returns:
        The trained model, moved to CPU and switched to eval mode so that
        dropout is disabled at inference time.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    data = data.to(device)
    pos_edges = pos_edges.to(device)
    neg_edges = neg_edges.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-4)
    best_loss = float('inf')
    patience_counter = 0

    logger.info("Starting model training")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        embeddings = model(data.x, data.edge_index)
        pos_scores = (embeddings[pos_edges[0]] * embeddings[pos_edges[1]]).sum(dim=1)
        neg_scores = (embeddings[neg_edges[0]] * embeddings[neg_edges[1]]).sum(dim=1)
        pos_loss = F.binary_cross_entropy_with_logits(pos_scores, torch.ones_like(pos_scores))
        neg_loss = F.binary_cross_entropy_with_logits(neg_scores, torch.zeros_like(neg_scores))
        reg_loss = torch.norm(embeddings, p=2)
        total_loss = pos_loss + neg_loss + 0.001 * reg_loss
        total_loss.backward()
        optimizer.step()

        # .item() detaches from the autograd graph; comparing raw tensors would
        # keep the whole graph alive across epochs.
        loss_value = total_loss.item()
        if loss_value < best_loss:
            best_loss = loss_value
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(f"Early stopping at epoch {epoch}")
                break

    logger.info("Model training completed")
    model.eval()  # disable dropout for deterministic inference
    return model.to('cpu')


def get_recommendations(user_id, model, data, G, user_ids, top_k=10):
    """Generate top-k user recommendations excluding current follows.

    Returns a list of up to `top_k` user ids ranked by embedding dot-product
    score; empty if the user is unknown or already follows everyone.
    """
    if user_id not in user_ids:
        logger.warning(f"User {user_id} not found in graph")
        return []

    user_idx = user_ids.index(user_id)
    current_follows = set(G.successors(user_id))  # Users this user follows
    candidate_indices = [i for i, u in enumerate(user_ids)
                         if u != user_id and u not in current_follows]
    if not candidate_indices:
        logger.info(f"No new recommendations available for user {user_id}")
        return []

    model.eval()  # ensure dropout is off even if the caller skipped it
    with torch.no_grad():
        embeddings = model(data.x, data.edge_index)
        user_embed = embeddings[user_idx].unsqueeze(0)
        candidate_embeds = embeddings[candidate_indices]
        # squeeze(0) (not squeeze()) so a single candidate still yields a
        # 1-D tensor instead of a 0-dim scalar, which topk/argsort reject.
        scores = torch.matmul(user_embed, candidate_embeds.T).squeeze(0)
        k = min(top_k, scores.numel())
        top_indices = torch.topk(scores, k).indices

    recommendations = [user_ids[candidate_indices[i]] for i in top_indices.tolist()]
    logger.info(f"Generated {len(recommendations)} recommendations for user {user_id}")
    return recommendations


def rebuild_model():
    """Rebuild the graph and retrain the model, replacing the global state."""
    global G, features, user_ids, pyg_data, trained_model
    logger.info("Starting model rebuild at 3:30 AM Pacific Time")
    try:
        merged_data = load_and_preprocess_data()
        G, features, user_ids = create_graph_dataframe(merged_data)
        pos_edge_index, neg_edge_index = prepare_training_data(G, user_ids)
        # Build the PyG graph directly from the sorted-index edge tensor.
        # from_networkx() numbers nodes in G.nodes() insertion order, which
        # differs from sorted(user_ids) and would misalign edge_index/x with
        # the indices used for training and recommendation lookups.
        pyg_data = Data(x=features, edge_index=pos_edge_index)
        model = GraphRecommender(input_dim=len(user_ids))
        trained_model = train_model(model, pyg_data, pos_edge_index, neg_edge_index)
        logger.info("Model rebuild completed successfully")
    except Exception as e:
        logger.error(f"Error during model rebuild: {str(e)}")
        raise


@app.post("/rebuild")
async def rebuild_handler():
    """API endpoint to manually trigger model rebuild."""
    rebuild_model()
    return {"status": "success", "message": "Model and data rebuilt successfully"}


@app.get("/recommend/network")
async def get_recommendations_handler(user_id: str = Query(...)):
    """API endpoint to get recommendations for a user."""
    if not trained_model:
        raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
    if not user_id.strip():
        raise HTTPException(status_code=400, detail="Invalid user_id")

    recommendations = get_recommendations(user_id, trained_model, pyg_data, G, user_ids)

    def generate():
        """Stream the recommendation list as a JSON document."""
        yield '{"status": "success", "recommendations": ['
        for i, rec in enumerate(recommendations):
            yield json.dumps(rec)
            if i < len(recommendations) - 1:
                yield ','
        yield ']}'

    return StreamingResponse(generate(), media_type="application/json")


@app.get("/")
async def health_check():
    """API endpoint for health check."""
    return {"status": "success", "message": "Recommendation service operational"}


def ping_servers():
    """Keep-alive ping for the sibling Hugging Face Space services.

    `requests` is imported at module level so that the except clause below can
    always resolve requests.RequestException (a local import inside the try
    block would raise NameError in the handler if the import itself failed).
    """
    notification_server_url = "https://andykrik-notificationservice.hf.space/"
    film_recommender_server_url = "https://andykrik-filmrecommender.hf.space/"
    try:
        requests.get(notification_server_url)
        requests.get(film_recommender_server_url)
        logger.info("Pinged notification and film recommender servers successfully")
    except requests.RequestException as e:
        logger.error(f"Error pinging servers: {str(e)}")
        raise


# Scheduler setup with Pacific Time Zone
scheduler = BackgroundScheduler(timezone="America/Los_Angeles")
scheduler.add_job(
    rebuild_model,
    trigger=CronTrigger(hour=3, minute=30),
    id='daily_model_rebuild',
    replace_existing=True
)
scheduler.add_job(
    ping_servers,
    trigger=IntervalTrigger(hours=1),
    id='hourly_model_ping',
    replace_existing=True
)


@app.on_event("startup")
async def startup_event():
    """Startup event to initialize model and scheduler."""
    rebuild_model()
    scheduler.start()
    logger.info("Scheduler started, model will rebuild daily at 3:30 AM Pacific Time")


@app.on_event("shutdown")
async def shutdown_event():
    """Shutdown event to stop scheduler."""
    scheduler.shutdown()
    logger.info("Scheduler shut down")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)