import os
import random
import itertools
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.utils import from_networkx
from torch_geometric.nn import SAGEConv
from supabase import create_client
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from dotenv import load_dotenv
import json
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.interval import IntervalTrigger
from apscheduler.triggers.cron import CronTrigger
import logging
import requests
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
load_dotenv()
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
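# NOTE: wildcard CORS origins are convenient for development; a production
# deployment would normally pin allow_origins to the frontend's domain.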
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(SEED)
# Global variables
G = None
features = None
user_ids = None
pyg_data = None
trained_model = None
SUPABASE_URL = os.getenv('supabaseUrl')
SUPABASE_KEY = os.getenv('supabaseAnonKey')
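# Expected .env entries (names must match the keys read above; the URL shape
# shown is the standard Supabase project URL and is only illustrative):
#   supabaseUrl=https://<project-ref>.supabase.co
#   supabaseAnonKey=<anon/public API key>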
def get_supabase_client():
"""Initialize and return a Supabase client."""
return create_client(SUPABASE_URL, SUPABASE_KEY)
def load_and_preprocess_data():
"""Load and preprocess follower data from Supabase."""
supabase = get_supabase_client()
logger.info("Loading data from Supabase")
def fetch_table(table, columns, chunk_size=1000):
"""Fetch data from a Supabase table in chunks."""
offset = 0
all_data = []
while True:
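            # Supabase's .range(start, end) is inclusive on both ends, hence the -1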
response = supabase.table(table).select(columns).range(offset, offset + chunk_size - 1).execute()
data = response.data
if not data:
break
all_data.extend(data)
offset += chunk_size
return all_data
followers = fetch_table('followers', 'id, following')
users = fetch_table('profiles', 'id')
    # Build follower_dict: followed user id -> list of the ids that follow them
    follower_dict = {}
    for f in followers:
        followed_id = f['id']         # the user being followed
        follower_id = f['following']  # the user doing the following
        if not follower_id or not followed_id:  # skip malformed rows
            logger.warning(f"Skipping invalid entry: follower_id={follower_id}, followed_id={followed_id}")
            continue
        follower_dict.setdefault(followed_id, []).append(follower_id)
    user_set = {u['id'] for u in users if u['id']}  # valid user IDs from profiles
    # Edge list: follower -> followed, keeping only IDs that exist in profiles
merged = [
{'follower_id': follower, 'followed_id': fid}
for fid in follower_dict
for follower in follower_dict[fid]
if fid in user_set and follower in user_set
]
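    # Each record looks like {'follower_id': '<uuid>', 'followed_id': '<uuid>'} (hypothetical IDs)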
logger.info(f"Loaded {len(merged)} follower relationships")
return merged
def create_graph_dataframe(merged_data):
    """Create a directed graph and feature matrix from merged data."""
    global G, features, user_ids
    G = nx.DiGraph()
    edges = [(d['follower_id'], d['followed_id']) for d in merged_data]
    # Add nodes in sorted order *before* the edges so that from_networkx()
    # assigns each user the same integer index that user_ids implies below;
    # otherwise edge_index and the training edge indices would disagree.
    user_ids = sorted({n for edge in edges for n in edge})
    G.add_nodes_from(user_ids)
    G.add_edges_from(edges)
    # One-hot identity features: the first SAGEConv layer then acts as a learnable
    # per-user embedding table. This costs O(N^2) memory and fixes the node count.
    features = torch.eye(len(user_ids))
    logger.info(f"Created graph with {len(user_ids)} nodes")
    return G, features, user_ids
def prepare_training_data(G, user_ids):
    """Prepare positive and negative edge indices for training."""
    idx = {u: i for i, u in enumerate(user_ids)}  # dict lookup instead of O(n) list.index()
    pos_edges = [(idx[u], idx[v]) for u, v in G.edges()]
    pos_edge_index = torch.tensor(pos_edges, dtype=torch.long).t()
    num_nodes = len(user_ids)
    # NOTE: materializing every ordered pair is O(n^2) in memory; fine for small
    # graphs, but a large graph would need rejection sampling instead.
    all_possible_edges = set(itertools.permutations(range(num_nodes), 2))
    existing_edges = set(zip(pos_edge_index[0].tolist(), pos_edge_index[1].tolist()))
    neg_sample_size = len(pos_edges)
    negative_edges = random.sample(list(all_possible_edges - existing_edges), neg_sample_size)
    logger.info(f"Prepared {len(pos_edges)} positive and {len(negative_edges)} negative edges")
    return pos_edge_index, torch.tensor(negative_edges, dtype=torch.long).t()
class GraphRecommender(nn.Module):
"""GraphSAGE-based recommendation model."""
def __init__(self, input_dim, hidden_dim=128, output_dim=64):
super().__init__()
self.conv1 = SAGEConv(input_dim, hidden_dim)
self.conv2 = SAGEConv(hidden_dim, output_dim)
self.dropout = nn.Dropout(0.3)
def forward(self, x, edge_index):
x = F.relu(self.conv1(x, edge_index))
x = self.dropout(x)
x = self.conv2(x, edge_index)
return x
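# The model emits one 64-dim embedding per user; training and serving both score
# a potential follow edge as the dot product of the two endpoint embeddings.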
def train_model(model, data, pos_edges, neg_edges, epochs=200, patience=20):
"""Train the GraphRecommender model."""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device)
pos_edges = pos_edges.to(device)
neg_edges = neg_edges.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-4)
best_loss = float('inf')
patience_counter = 0
logger.info("Starting model training")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        embeddings = model(data.x, data.edge_index)
        # Logits for link prediction: inner product of the endpoint embeddings
        pos_scores = (embeddings[pos_edges[0]] * embeddings[pos_edges[1]]).sum(dim=1)
        neg_scores = (embeddings[neg_edges[0]] * embeddings[neg_edges[1]]).sum(dim=1)
        pos_loss = F.binary_cross_entropy_with_logits(pos_scores, torch.ones_like(pos_scores))
        neg_loss = F.binary_cross_entropy_with_logits(neg_scores, torch.zeros_like(neg_scores))
        reg_loss = torch.norm(embeddings, p=2)  # global L2 penalty on the embeddings
        total_loss = pos_loss + neg_loss + 0.001 * reg_loss
        total_loss.backward()
        optimizer.step()
        # Track a plain float; keeping the loss tensor would retain the autograd
        # graph of every "best" epoch and leak memory.
        loss_value = total_loss.item()
        if loss_value < best_loss:
            best_loss = loss_value
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(f"Early stopping at epoch {epoch} (best loss {best_loss:.4f})")
                break
    logger.info("Model training completed")
    # Data.to(device) above moves the shared pyg_data tensors in place; move them
    # back so inference runs on the same (CPU) device as the returned model.
    data.to('cpu')
    return model.to('cpu')
def get_recommendations(user_id, model, data, G, user_ids, top_k=10):
    """Generate top-k user recommendations excluding current follows."""
    if user_id not in user_ids:
        logger.warning(f"User {user_id} not found in graph")
        return []
    user_idx = user_ids.index(user_id)
    current_follows = set(G.successors(user_id))  # users this user already follows
    candidate_indices = [i for i, u in enumerate(user_ids) if u != user_id and u not in current_follows]
    if not candidate_indices:
        logger.info(f"No new recommendations available for user {user_id}")
        return []
    model.eval()  # disable dropout at inference; no_grad alone does not do this
    with torch.no_grad():
        embeddings = model(data.x, data.edge_index)
        user_embed = embeddings[user_idx].unsqueeze(0)
        candidate_embeds = embeddings[candidate_indices]
        # Rank candidates by the same dot-product score used during training
        scores = torch.matmul(user_embed, candidate_embeds.T).squeeze(0)
        top_indices = scores.argsort(descending=True)[:top_k].tolist()
    recommendations = [user_ids[candidate_indices[i]] for i in top_indices]
    logger.info(f"Generated {len(recommendations)} recommendations for user {user_id}")
    return recommendations
def rebuild_model():
"""Rebuild the graph and retrain the model."""
global G, features, user_ids, pyg_data, trained_model
logger.info("Starting model rebuild at 3:30 AM Pacific Time")
try:
merged_data = load_and_preprocess_data()
G, features, user_ids = create_graph_dataframe(merged_data)
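        # from_networkx() numbers nodes in insertion order, which matches
        # user_ids because create_graph_dataframe() adds nodes in sorted order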
pyg_data = from_networkx(G)
pyg_data.x = features
pos_edge_index, neg_edge_index = prepare_training_data(G, user_ids)
model = GraphRecommender(input_dim=len(user_ids))
trained_model = train_model(model, pyg_data, pos_edge_index, neg_edge_index)
logger.info("Model rebuild completed successfully")
except Exception as e:
logger.error(f"Error during model rebuild: {str(e)}")
raise
@app.post("/rebuild")
def rebuild_handler():  # sync def: FastAPI runs it in a threadpool so retraining doesn't block the event loop
"""API endpoint to manually trigger model rebuild."""
rebuild_model()
return {"status": "success", "message": "Model and data rebuilt successfully"}
@app.get("/recommend/network")
def get_recommendations_handler(user_id: str = Query(...)):  # sync def keeps the model forward pass off the event loop
    """API endpoint to get recommendations for a user."""
    if trained_model is None:
        raise HTTPException(status_code=503, detail="Model not initialized, please rebuild first.")
if not user_id.strip():
raise HTTPException(status_code=400, detail="Invalid user_id")
recommendations = get_recommendations(user_id, trained_model, pyg_data, G, user_ids)
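    # Stream the JSON array element by element so long recommendation lists start
    # flushing immediately instead of being buffered into one response body.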
def generate():
yield '{"status": "success", "recommendations": ['
for i, rec in enumerate(recommendations):
yield json.dumps(rec)
if i < len(recommendations) - 1:
yield ','
yield ']}'
return StreamingResponse(generate(), media_type="application/json")
@app.get("/")
async def health_check():
"""API endpoint for health check."""
return {"status": "success", "message": "Recommendation service operational"}
def ping_servers():
    """Ping the companion services (likely to keep these Hugging Face Spaces awake)."""
    notification_server_url = "https://andykrik-notificationservice.hf.space/"
    film_recommender_server_url = "https://andykrik-filmrecommender.hf.space/"
    try:
        # Timeouts prevent a hung Space from stalling the scheduler thread
        requests.get(notification_server_url, timeout=30)
        requests.get(film_recommender_server_url, timeout=30)
        logger.info("Pinged notification and film recommender servers successfully")
    except requests.RequestException as e:
        logger.error(f"Error pinging servers: {str(e)}")
        raise
# Scheduler setup with Pacific Time Zone
scheduler = BackgroundScheduler(timezone="America/Los_Angeles")
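# Jobs inherit the scheduler's timezone unless their trigger specifies its own.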
scheduler.add_job(
rebuild_model,
trigger=CronTrigger(hour=3, minute=30),
id='daily_model_rebuild',
replace_existing=True
)
scheduler.add_job(
ping_servers,
trigger=IntervalTrigger(hours=1),
    id='hourly_server_ping',
replace_existing=True
)
@app.on_event("startup")
async def startup_event():
"""Startup event to initialize model and scheduler."""
rebuild_model()
scheduler.start()
logger.info("Scheduler started, model will rebuild daily at 3:30 AM Pacific Time")
@app.on_event("shutdown")
async def shutdown_event():
"""Shutdown event to stop scheduler."""
scheduler.shutdown()
logger.info("Scheduler shut down")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)