# (Hugging Face Space status banner removed — scrape residue, not source code.)
| import os | |
| import random | |
| import itertools | |
| import numpy as np | |
| import networkx as nx | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.optim as optim | |
| from torch_geometric.utils import from_networkx | |
| from torch_geometric.nn import SAGEConv | |
| from supabase import create_client | |
| from fastapi import FastAPI, HTTPException, Query | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from dotenv import load_dotenv | |
| import json | |
| from apscheduler.schedulers.background import BackgroundScheduler | |
| from apscheduler.triggers.interval import IntervalTrigger | |
| from apscheduler.triggers.cron import CronTrigger | |
| import logging | |
| import pytz | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
| logger = logging.getLogger(__name__) | |
| load_dotenv() | |
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # rejected by browsers for credentialed requests — confirm this is intended.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Fix every RNG source so graph sampling and model training are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
# Global variables
# Module-level state populated by rebuild_model(); all remain None until the
# first successful build.
G = None              # directed follower graph (networkx.DiGraph)
features = None       # identity matrix of node features (torch.Tensor)
user_ids = None       # sorted node ids; row i of features corresponds to user_ids[i]
pyg_data = None       # torch_geometric data object built from G
trained_model = None  # trained GraphRecommender instance
# NOTE(review): camelCase env-var names are unusual — confirm they match the
# deployment configuration.
SUPABASE_URL = os.getenv('supabaseUrl')
SUPABASE_KEY = os.getenv('supabaseAnonKey')
def get_supabase_client():
    """Build and return a Supabase client from the module-level URL and key."""
    url, key = SUPABASE_URL, SUPABASE_KEY
    return create_client(url, key)
def load_and_preprocess_data():
    """Load follower relationships from Supabase and return valid edges.

    Returns a list of ``{'follower_id': ..., 'followed_id': ...}`` dicts,
    restricted to pairs where both users exist in the ``profiles`` table.
    """
    supabase = get_supabase_client()
    logger.info("Loading data from Supabase")

    def fetch_table(table, columns, chunk_size=1000):
        """Fetch all rows of `columns` from `table`, paging `chunk_size` rows at a time."""
        offset = 0
        all_data = []
        while True:
            response = supabase.table(table).select(columns).range(offset, offset + chunk_size - 1).execute()
            data = response.data
            if not data:
                break
            all_data.extend(data)
            # A short page means we just received the last rows; stop without
            # issuing one extra empty round-trip.
            if len(data) < chunk_size:
                break
            offset += chunk_size
        return all_data

    followers = fetch_table('followers', 'id, following')
    users = fetch_table('profiles', 'id')
    # follower_dict: followed user id -> list of follower ids.
    # NOTE(review): column semantics ('id' = followed, 'following' = follower)
    # are taken from the original comments — confirm against the table schema.
    follower_dict = {}
    for f in followers:
        followed_id = f['id']        # the user being followed
        follower_id = f['following']  # the user doing the following
        if not follower_id or not followed_id:  # skip rows with missing ids
            logger.warning(f"Skipping invalid entry: follower_id={follower_id}, followed_id={followed_id}")
            continue
        follower_dict.setdefault(followed_id, []).append(follower_id)
    user_set = {u['id'] for u in users if u['id']}  # valid user ids
    # Edge list: follower -> followed, keeping only edges between known users.
    merged = [
        {'follower_id': follower, 'followed_id': fid}
        for fid, follower_list in follower_dict.items()
        if fid in user_set
        for follower in follower_list
        if follower in user_set
    ]
    logger.info(f"Loaded {len(merged)} follower relationships")
    return merged
def create_graph_dataframe(merged_data):
    """Build the global directed follow graph and identity node features."""
    global G, features, user_ids
    G = nx.DiGraph()
    G.add_edges_from((row['follower_id'], row['followed_id']) for row in merged_data)
    user_ids = sorted(G.nodes())
    # One-hot identity features: node i gets basis vector e_i.
    features = torch.eye(len(user_ids))
    logger.info(f"Created graph with {len(user_ids)} nodes")
    return G, features, user_ids
def prepare_training_data(G, user_ids):
    """Prepare positive and negative edge-index tensors for link prediction.

    Returns a pair of (2, E) long tensors: observed edges and an equal number
    of sampled non-edges (no self-loops, no duplicates).
    """
    # O(1) id->index lookups; the previous user_ids.index() was O(N) per edge.
    index_of = {u: i for i, u in enumerate(user_ids)}
    pos_edges = [(index_of[u], index_of[v]) for u, v in G.edges()]
    pos_edge_index = torch.tensor(pos_edges, dtype=torch.long).t()
    num_nodes = len(user_ids)
    existing_edges = set(pos_edges)
    neg_sample_size = len(pos_edges)
    # Rejection sampling avoids materializing all n*(n-1) candidate pairs,
    # which the previous itertools.permutations approach did (infeasible
    # memory for graphs beyond a few thousand nodes).
    negative_edges = []
    seen = set()
    attempts = 0
    max_attempts = 20 * neg_sample_size + 100
    while len(negative_edges) < neg_sample_size and attempts < max_attempts:
        attempts += 1
        u = random.randrange(num_nodes)
        v = random.randrange(num_nodes)
        if u == v:
            continue  # no self-loops, matching permutations(..., 2)
        pair = (u, v)
        if pair in existing_edges or pair in seen:
            continue
        seen.add(pair)
        negative_edges.append(pair)
    if len(negative_edges) < neg_sample_size:
        # Dense-graph fallback: enumerate the remaining non-edges explicitly.
        remaining = [
            (u, v)
            for u in range(num_nodes)
            for v in range(num_nodes)
            if u != v and (u, v) not in existing_edges and (u, v) not in seen
        ]
        take = min(neg_sample_size - len(negative_edges), len(remaining))
        negative_edges.extend(random.sample(remaining, take))
    logger.info(f"Prepared {len(pos_edges)} positive and {len(negative_edges)} negative edges")
    return pos_edge_index, torch.tensor(negative_edges, dtype=torch.long).t()
class GraphRecommender(nn.Module):
    """Two-layer GraphSAGE encoder producing node embeddings for link scoring."""

    def __init__(self, input_dim, hidden_dim=128, output_dim=64):
        super().__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, edge_index):
        # conv1 -> ReLU -> dropout -> conv2; output is the node embedding matrix.
        hidden = self.dropout(F.relu(self.conv1(x, edge_index)))
        return self.conv2(hidden, edge_index)
def train_model(model, data, pos_edges, neg_edges, epochs=200, patience=20):
    """Train the recommender with BCE on dot-product edge scores.

    Args:
        model: module mapping (x, edge_index) -> node embeddings.
        data: object with `.x`, `.edge_index`, and `.to(device)`.
        pos_edges / neg_edges: (2, E) long tensors of edge endpoints.
        epochs: maximum training epochs.
        patience: epochs without improvement before early stopping.

    Returns:
        The model on CPU, with the best-loss weights restored.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    data = data.to(device)
    pos_edges = pos_edges.to(device)
    neg_edges = neg_edges.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-4)
    best_loss = float('inf')
    best_state = None
    patience_counter = 0
    logger.info("Starting model training")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        embeddings = model(data.x, data.edge_index)
        # Dot-product scores for observed vs. sampled non-edges.
        pos_scores = (embeddings[pos_edges[0]] * embeddings[pos_edges[1]]).sum(dim=1)
        neg_scores = (embeddings[neg_edges[0]] * embeddings[neg_edges[1]]).sum(dim=1)
        pos_loss = F.binary_cross_entropy_with_logits(pos_scores, torch.ones_like(pos_scores))
        neg_loss = F.binary_cross_entropy_with_logits(neg_scores, torch.zeros_like(neg_scores))
        reg_loss = torch.norm(embeddings, p=2)  # L2 penalty on embedding magnitudes
        total_loss = pos_loss + neg_loss + 0.001 * reg_loss
        total_loss.backward()
        optimizer.step()
        # .item() detaches the scalar; the previous tensor comparison retained
        # the full autograd graph of the "best" epoch in memory.
        loss_val = total_loss.item()
        if loss_val < best_loss:
            best_loss = loss_val
            # Snapshot the best weights so early stopping can restore them —
            # previously the last (worse) epoch's weights were returned.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(f"Early stopping at epoch {epoch}")
                break
    if best_state is not None:
        model.load_state_dict(best_state)
    logger.info("Model training completed")
    return model.to('cpu')
def get_recommendations(user_id, model, data, G, user_ids, top_k=10):
    """Return up to `top_k` recommended user ids for `user_id`.

    Excludes the user themself and everyone they already follow. Returns an
    empty list when the user is unknown or no candidates remain.
    """
    if user_id not in user_ids:
        logger.warning(f"User {user_id} not found in graph")
        return []
    user_idx = user_ids.index(user_id)
    current_follows = set(G.successors(user_id))  # users this user already follows
    candidate_indices = [i for i, u in enumerate(user_ids) if u != user_id and u not in current_follows]
    if not candidate_indices:
        logger.info(f"No new recommendations available for user {user_id}")
        return []
    # BUG FIX: eval() was never called, so Dropout(0.3) stayed active at
    # inference time and recommendations were nondeterministic.
    model.eval()
    with torch.no_grad():
        embeddings = model(data.x, data.edge_index)
        user_embed = embeddings[user_idx].unsqueeze(0)
        candidate_embeds = embeddings[candidate_indices]
        # squeeze(0), not squeeze(): with exactly one candidate a full squeeze
        # yields a 0-d tensor and the [:top_k] slice below would raise.
        scores = torch.matmul(user_embed, candidate_embeds.T).squeeze(0)
    top_indices = scores.argsort(descending=True)[:top_k]
    recommendations = [user_ids[candidate_indices[i]] for i in top_indices]
    logger.info(f"Generated {len(recommendations)} recommendations for user {user_id}")
    return recommendations
def rebuild_model():
    """Reload follower data, rebuild the graph, and retrain the recommender.

    Replaces all module-level state (G, features, user_ids, pyg_data,
    trained_model). Any failure is logged and re-raised to the caller.
    """
    global G, features, user_ids, pyg_data, trained_model
    logger.info("Starting model rebuild at 3:30 AM Pacific Time")
    try:
        edges = load_and_preprocess_data()
        G, features, user_ids = create_graph_dataframe(edges)
        pyg_data = from_networkx(G)
        pyg_data.x = features
        pos_idx, neg_idx = prepare_training_data(G, user_ids)
        recommender = GraphRecommender(input_dim=len(user_ids))
        trained_model = train_model(recommender, pyg_data, pos_idx, neg_idx)
        logger.info("Model rebuild completed successfully")
    except Exception as e:
        logger.error(f"Error during model rebuild: {str(e)}")
        raise
async def rebuild_handler():
    """API endpoint to manually trigger model rebuild.

    NOTE(review): no route decorator is visible in this chunk — confirm the
    handler is registered with the app elsewhere. rebuild_model() is
    synchronous, so this request blocks until retraining finishes.
    """
    rebuild_model()
    return {"status": "success", "message": "Model and data rebuilt successfully"}
async def get_recommendations_handler(user_id: str = Query(...)):
    """Stream the top recommendations for `user_id` as a JSON response.

    Raises HTTP 500 when the model has not been built yet and HTTP 400 for a
    blank user id. NOTE(review): no route decorator is visible in this chunk —
    confirm the handler is registered with the app elsewhere.
    """
    # `is None` instead of truthiness: the sentinel for "not built" is None,
    # and relying on an nn.Module's truth value is fragile.
    if trained_model is None:
        raise HTTPException(status_code=500, detail="Model not initialized, please rebuild first.")
    if not user_id.strip():
        raise HTTPException(status_code=400, detail="Invalid user_id")
    recommendations = get_recommendations(user_id, trained_model, pyg_data, G, user_ids)

    def generate():
        # Emit the JSON envelope incrementally, one recommendation at a time.
        yield '{"status": "success", "recommendations": ['
        for i, rec in enumerate(recommendations):
            yield json.dumps(rec)
            if i < len(recommendations) - 1:
                yield ','
        yield ']}'

    return StreamingResponse(generate(), media_type="application/json")
async def health_check():
    """Liveness probe: report that the recommendation service is up."""
    return {"status": "success", "message": "Recommendation service operational"}
def ping_servers():
    """Keep-alive pings to sibling Hugging Face Spaces so they don't idle out.

    Logs and re-raises on any request failure (APScheduler records the error).
    """
    # BUG FIX: the import previously sat inside the `try`; if it failed, the
    # `except requests.RequestException` clause itself raised NameError.
    import requests

    notification_server_url = "https://andykrik-notificationservice.hf.space/"
    film_recommender_server_url = "https://andykrik-filmrecommender.hf.space/"
    try:
        # Timeouts added: requests.get without a timeout can block the
        # scheduler thread indefinitely.
        requests.get(notification_server_url, timeout=30)
        requests.get(film_recommender_server_url, timeout=30)
        logger.info("Pinged notification and film recommender servers successfully")
    except requests.RequestException as e:
        logger.error(f"Error pinging servers: {str(e)}")
        raise
# Scheduler setup with Pacific Time Zone
# Background scheduler drives the daily model rebuild and hourly keep-alive
# pings; jobs run on scheduler threads, not the request loop.
scheduler = BackgroundScheduler(timezone="America/Los_Angeles")
scheduler.add_job(
    rebuild_model,
    # Fires daily at 03:30 in the scheduler's America/Los_Angeles timezone.
    trigger=CronTrigger(hour=3, minute=30),
    id='daily_model_rebuild',
    replace_existing=True
)
scheduler.add_job(
    ping_servers,
    # Every hour, keep the sibling Spaces from idling out.
    trigger=IntervalTrigger(hours=1),
    id='hourly_model_ping',
    replace_existing=True
)
async def startup_event():
    """Startup event to initialize model and scheduler.

    Builds the model synchronously before serving traffic, then starts the
    background scheduler. NOTE(review): no @app.on_event("startup") decorator
    is visible in this chunk — confirm this coroutine is actually registered.
    """
    rebuild_model()
    scheduler.start()
    logger.info("Scheduler started, model will rebuild daily at 3:30 AM Pacific Time")
async def shutdown_event():
    """Shutdown event to stop scheduler.

    NOTE(review): no @app.on_event("shutdown") decorator is visible in this
    chunk — confirm this coroutine is actually registered.
    """
    scheduler.shutdown()
    logger.info("Scheduler shut down")
if __name__ == "__main__":
    # Local/dev entry point; presumably production runs uvicorn externally —
    # confirm against the deployment setup.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)