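"""FastAPI recommendation service (module summary added; behavior inferred from the code below).

Builds per-user taste profiles from Supabase interaction tables (likes, comments,
comment likes, replies, reply likes) with exponential time decay, embeds post text
with the all-MiniLM-L6-v2 sentence-transformer, and serves ranked post
recommendations that blend similarity, popularity, and freshness.
"""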
import os
from fastapi import FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
from supabase import create_client
import numpy as np
from collections import defaultdict
import random
from datetime import datetime, timedelta
import asyncio
from typing import List, Dict
from dotenv import load_dotenv
import logging
from zoneinfo import ZoneInfo
from sentence_transformers import SentenceTransformer
from apscheduler.schedulers.background import BackgroundScheduler
import math
from contextlib import asynccontextmanager  # Used for the lifespan context manager below

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# Configuration
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
TOP_K = 75
HISTORY_WINDOW = timedelta(days=1000)
TIMEZONE = ZoneInfo("UTC")
UPDATE_INTERVAL = 300  # In seconds (5 minutes)
START_TIME = datetime.now(TIMEZONE)

# Global variables
supabase_client = None
user_interactions = defaultdict(set)
post_features = {}
post_metadata = {}

if not os.path.exists('/tmp/cache'):
    os.makedirs('/tmp/cache')
sentence_model = SentenceTransformer('all-MiniLM-L6-v2', cache_folder='/tmp/cache')

SUPABASE_URL = os.getenv('supabaseUrl')
SUPABASE_KEY = os.getenv('supabaseAnonKey')


def get_supabase_client():
    global supabase_client
    if supabase_client is None:
        supabase_client = create_client(SUPABASE_URL, SUPABASE_KEY)
    return supabase_client
def parse_datetime(dt_str: str) -> datetime:
    """Parse ISO datetime string and ensure correct microsecond precision."""
    try:
        if '.' in dt_str:
            date_part, time_part = dt_str.split('T')
            time_part, tz_part = time_part.split('+')
            if '.' in time_part:
                time_without_micro, micro = time_part.split('.')
                micro = micro.ljust(6, '0')
                time_part = f"{time_without_micro}.{micro}"
            dt_str = f"{date_part}T{time_part}+00:00"
        return datetime.fromisoformat(dt_str).astimezone(TIMEZONE)
    except Exception as e:
        logger.error(f"Error parsing datetime: {dt_str} - {str(e)}")
        raise
def decay_weight(interaction_time: datetime, current_time: datetime, decay_rate: float = 0.0001) -> float:
    """
    Compute an exponential decay weight for an interaction based on its age.
    The decay_rate can be tuned to control how fast interactions lose weight.
    """
    time_diff = (current_time - interaction_time).total_seconds()
    return math.exp(-decay_rate * time_diff)


def normalize_profile(profile: np.ndarray) -> np.ndarray:
    norm = np.linalg.norm(profile)
    return profile / norm if norm > 0 else profile


def compute_score(sim_score: float, popularity: float, freshness: float) -> float:
    """
    Compute a non-linear recommendation score.
    - sim_score is squared to emphasize strong similarities.
    - Popularity is log-transformed.
    - Freshness is combined linearly.
    """
    return 0.4 * (sim_score ** 2) + 0.3 * np.log1p(popularity) + 0.3 * freshness
class Recommender:
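    """Content-based recommender (docstring added; behavior inferred from the methods below).

    Maintains a normalized embedding profile per user and a time-decayed popularity
    score per post, then ranks unseen posts with compute_score() plus a little
    random noise for exploration.
    """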
    def __init__(self, like_weight=1.0, comment_weight=0.5, comment_like_weight=0.3,
                 reply_weight=0.5, reply_like_weight=0.3):
        self.user_profiles = defaultdict(lambda: np.zeros(384))  # 384 = all-MiniLM-L6-v2 embedding size
        self.post_popularity = defaultdict(float)
        self.last_update = datetime.now(TIMEZONE) - HISTORY_WINDOW
        # Parameterized weights
        self.like_weight = like_weight
        self.comment_weight = comment_weight
        self.comment_like_weight = comment_like_weight
        self.reply_weight = reply_weight
        self.reply_like_weight = reply_like_weight

    async def fetch_existing_post_ids(self) -> set:
        supabase = get_supabase_client()
        page_size = 1000
        page = 0
        post_ids = set()
        while True:
            response = await asyncio.to_thread(
                supabase.table('posts')
                .select('id')
                .range(page * page_size, (page + 1) * page_size - 1)
                .execute
            )
            if not response.data:
                break
            for row in response.data:
                post_ids.add(row['id'])
            page += 1
        return post_ids

    def clean_deleted_posts(self, existing_post_ids: set):
        # Determine which posts are missing
        all_cached_ids = set(post_metadata.keys())
        deleted_ids = all_cached_ids - existing_post_ids
        for post_id in deleted_ids:
            post_metadata.pop(post_id, None)
            post_features.pop(post_id, None)
            self.post_popularity.pop(post_id, None)
        for user_id in user_interactions:
            user_interactions[user_id] -= deleted_ids  # remove any deleted post_ids
    async def fetch_all_rows(self, table_name: str, columns: str, last_update: datetime, post_id_not_null: bool):
        """Fetch all rows from a table using pagination."""
        supabase = get_supabase_client()
        page_size = 1000
        page = 0
        all_data = []
        while True:
            if post_id_not_null:
                response = await asyncio.to_thread(
                    supabase.table(table_name)
                    .select(columns)
                    .gt('created_at', last_update.isoformat())
                    .not_.is_("post_id", None)
                    .range(page * page_size, (page + 1) * page_size - 1)
                    .execute
                )
            else:
                response = await asyncio.to_thread(
                    supabase.table(table_name)
                    .select(columns)
                    .gt('created_at', last_update.isoformat())
                    .range(page * page_size, (page + 1) * page_size - 1)
                    .execute
                )
            if not response.data:
                break
            all_data.extend(response.data)
            page += 1
        return all_data
    async def update_data(self):
        # Current time for decay calculation
        current_time = datetime.now(TIMEZONE)

        # Fetch all interaction data since last update
        likes = await self.fetch_all_rows('likes', 'user_id, post_id, created_at', self.last_update, True)
        comments = await self.fetch_all_rows('comments', 'id, user_id, post_id, created_at, comment', self.last_update, True)
        commentlikes = await self.fetch_all_rows('commentlikes', 'user_id, author_id, post_id, created_at, comment_id', self.last_update, True)
        replies = await self.fetch_all_rows('replies', 'user_id, to, post_id, created_at, comment, comment_id, reply_id', self.last_update, True)
        replylikes = await self.fetch_all_rows('replylikes', 'user_id, author_id, post_id, created_at, reply_id', self.last_update, True)

        # Process likes with time decay
        for like in likes:
            user_id = like['user_id']
            post_id = like['post_id']
            interaction_time = parse_datetime(like['created_at'])
            weight = decay_weight(interaction_time, current_time)
            user_interactions[user_id].add(post_id)
            self.post_popularity[post_id] += self.like_weight * weight
            if post_id in post_features:
                self.user_profiles[user_id] += post_features[post_id] * self.like_weight * weight

        # Process comments with time decay
        for comment in comments:
            user_id = comment['user_id']
            post_id = comment['post_id']
            interaction_time = parse_datetime(comment['created_at'])
            weight = decay_weight(interaction_time, current_time)
            user_interactions[user_id].add(post_id)
            self.post_popularity[post_id] += self.comment_weight * weight
            if post_id in post_features:
                self.user_profiles[user_id] += post_features[post_id] * self.comment_weight * weight

        # Process comment likes with time decay
        for clike in commentlikes:
            user_id = clike['user_id']  # User who liked the comment
            post_id = clike['post_id']
            interaction_time = parse_datetime(clike['created_at'])
            weight = decay_weight(interaction_time, current_time)
            user_interactions[user_id].add(post_id)
            self.post_popularity[post_id] += self.comment_like_weight * weight
            if post_id in post_features:
                self.user_profiles[user_id] += post_features[post_id] * self.comment_like_weight * weight

        # Process replies with time decay
        for reply in replies:
            user_id = reply['user_id']
            post_id = reply['post_id']
            interaction_time = parse_datetime(reply['created_at'])
            weight = decay_weight(interaction_time, current_time)
            user_interactions[user_id].add(post_id)
            self.post_popularity[post_id] += self.reply_weight * weight
            if post_id in post_features:
                self.user_profiles[user_id] += post_features[post_id] * self.reply_weight * weight

        # Process reply likes with time decay
        for rlike in replylikes:
            user_id = rlike['user_id']
            post_id = rlike['post_id']
            interaction_time = parse_datetime(rlike['created_at'])
            weight = decay_weight(interaction_time, current_time)
            user_interactions[user_id].add(post_id)
            self.post_popularity[post_id] += self.reply_like_weight * weight
            if post_id in post_features:
                self.user_profiles[user_id] += post_features[post_id] * self.reply_like_weight * weight

        # OPTIONAL: Process negative feedback if available
        # for negative in negative_interactions:
        #     user_id = negative['user_id']
        #     post_id = negative['post_id']
        #     interaction_time = parse_datetime(negative['created_at'])
        #     weight = decay_weight(interaction_time, current_time)
        #     user_interactions[user_id].add(post_id)
        #     self.post_popularity[post_id] -= some_negative_weight * weight
        #     if post_id in post_features:
        #         self.user_profiles[user_id] -= post_features[post_id] * some_negative_weight * weight

        # Normalize user profiles after processing interactions
        for user_id in self.user_profiles:
            self.user_profiles[user_id] = normalize_profile(self.user_profiles[user_id])

        # Fetch and update post features
        posts = await self.fetch_all_rows('posts', '*', self.last_update, False)
        post_texts, post_ids = [], []
        for post in posts:
            post_id = post['id']
            text = f"{post.get('movie_name', '')} {post.get('content', '')}".strip()
            post_texts.append(text)
            post_ids.append(post_id)
            post['created_at'] = parse_datetime(post['created_at'])
            post['type'] = 'post'
            post_metadata[post_id] = post
        if post_texts:
            embeddings = sentence_model.encode(post_texts, show_progress_bar=False, convert_to_numpy=True)
            for post_id, embedding in zip(post_ids, embeddings):
                post_features[post_id] = embedding / np.linalg.norm(embedding)

        # existing_post_ids = await self.fetch_existing_post_ids()
        # self.clean_deleted_posts(existing_post_ids)

        self.last_update = datetime.now(TIMEZONE)
        total_interactions = len(likes) + len(comments) + len(commentlikes) + len(replies) + len(replylikes)
        logger.info(f"Data updated: {len(posts)} posts, {total_interactions} interactions")
    def get_recommendations(self, user_id: str) -> List[Dict]:
        user_profile = self.user_profiles[user_id]
        seen_posts = user_interactions[user_id]
        scores = {}
        now = datetime.now(TIMEZONE)
        for post_id, feature in post_features.items():
            post = post_metadata[post_id]
            # Skip posts the user has already interacted with or authored
            if post_id in seen_posts or post.get("author") == user_id:
                continue
            sim_score = np.dot(user_profile, feature) if np.any(user_profile) else 0
            time_diff = now - post_metadata[post_id]['created_at']
            # Freshness decays exponentially with a three-day time constant
            freshness = math.exp(-time_diff.days / 3.0)
            score = compute_score(sim_score, self.post_popularity[post_id], freshness)
            # Add a small random noise for exploration
            scores[post_id] = score + random.uniform(-0.1, 0.1)
        top_posts = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:TOP_K]
        results = [post_metadata[post_id] for post_id, _ in top_posts]
        random.shuffle(results)
        return results
recommender = Recommender()
scheduler = BackgroundScheduler(timezone="UTC")


async def background_update():
    await recommender.update_data()


def sync_background_update():
    # APScheduler calls this from a worker thread, so drive the async update with a fresh event loop
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(background_update())
    loop.close()
# Lifespan context manager (the asynccontextmanager decorator is required for FastAPI's lifespan hook)
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup event
    logger.info("Starting up application...")
    await recommender.update_data()
    scheduler.add_job(sync_background_update, 'interval', seconds=UPDATE_INTERVAL)
    scheduler.start()
    yield
    # Shutdown event
    logger.info("Shutting down application...")
    scheduler.shutdown()
    logger.info("Scheduler shut down")
# FastAPI app setup with lifespan
app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE: the route decorator is missing in the source; "/recommendations" is an assumed path.
@app.get("/recommendations")
async def get_recommendations_handler(user_id: str = Query(...)):
    try:
        recommendations = recommender.get_recommendations(user_id)
        if recommendations:
            # Inject a "suggested accounts" card near the top if one is not already present
            if not any(item.get("type") == "suggestedaccounts" for item in recommendations):
                insert_pos = random.randint(0, min(9, len(recommendations) - 1))
                recommendations.insert(insert_pos, {"type": "suggestedaccounts"})
            # Interleave an ad card every ad_frequency items
            ad_frequency = 10
            i = ad_frequency
            while i < len(recommendations):
                recommendations.insert(i, {"type": "ad"})
                i += ad_frequency + 1  # Adjust for the inserted item
        return {"status": "success", "recommendations": recommendations}
    except Exception as e:
        logger.error(f"Error generating recommendations: {str(e)}")
        return {"status": "error", "message": str(e)}
# NOTE: the route decorator is missing in the source; "/health" is an assumed path.
@app.get("/health")
async def health_check():
    return {"status": "success", "message": "Service operational"}