import logging
import os

import numpy as np
from datasets import load_dataset, load_from_disk
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse

app = FastAPI()

# ---------------------------------------------------------
# Logging
# ---------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------
# CORS
# ---------------------------------------------------------
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ---------------------------------------------------------
# Dataset caching configuration
# ---------------------------------------------------------
DATASET_NAME = "kurry/sp500_earnings_transcripts"
CACHE_PATH = "/data/hf_dataset"  # persistent bucket mount

# Process-wide singletons: the loaded dataset and a lazily-built
# lookup index (upper-cased symbol -> list of row indices).  The index
# avoids an O(n) Python-level scan of the dataset on every request.
dataset_cache = None
_symbol_index = None


def load_hf_dataset():
    """Load the HF dataset with persistent caching.

    - If /data/hf_dataset exists -> load from disk (fast, offline)
    - Else -> download once, save to disk, then load

    Returns the (memoized) ``datasets.Dataset`` for the "train" split.
    """
    global dataset_cache
    if dataset_cache is not None:
        return dataset_cache

    if os.path.exists(CACHE_PATH):
        logger.info("Loading dataset from cache at %s", CACHE_PATH)
        dataset_cache = load_from_disk(CACHE_PATH)
        logger.info("Loaded %d rows from cached dataset", len(dataset_cache))
        return dataset_cache

    logger.info("Downloading HF dataset: %s", DATASET_NAME)
    ds = load_dataset(DATASET_NAME, split="train")
    logger.info("Saving dataset to cache at %s", CACHE_PATH)
    ds.save_to_disk(CACHE_PATH)
    dataset_cache = ds
    logger.info("Dataset cached and loaded (%d rows)", len(ds))
    return ds


def _get_symbol_index():
    """Return the symbol -> row-indices index, building it on first use.

    Built once per process so the per-symbol endpoints are O(k) lookups
    instead of a full scan of the dataset per request.
    """
    global _symbol_index
    if _symbol_index is None:
        ds = load_hf_dataset()
        index = {}
        for i, sym in enumerate(ds["symbol"]):
            index.setdefault(sym.upper(), []).append(i)
        _symbol_index = index
    return _symbol_index


# ---------------------------------------------------------
# JSON-safe conversion
# ---------------------------------------------------------
def to_json_safe(obj):
    """Recursively convert numpy scalars/arrays into plain Python types.

    Anything not recognized (str, int, None, ...) is returned unchanged.
    """
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    # Sequences (including numpy arrays and tuples) become plain lists.
    if isinstance(obj, (np.ndarray, list, tuple)):
        return [to_json_safe(x) for x in obj]
    if isinstance(obj, dict):
        return {k: to_json_safe(v) for k, v in obj.items()}
    return obj


# ---------------------------------------------------------
# Serve index.html
# ---------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
def serve_index():
    """Serve the front-end page, or a minimal error page if it is missing."""
    if not os.path.exists("index.html"):
        # NOTE(review): the original literal here was garbled in the source
        # (markup stripped); reconstructed as a simple heading.
        return "<h1>index.html not found</h1>"
    with open("index.html", "r", encoding="utf-8") as f:
        return f.read()


# ---------------------------------------------------------
# List all symbols
# ---------------------------------------------------------
@app.get("/tickers")
def get_tickers():
    """Return every distinct ticker symbol, upper-cased and sorted."""
    return {"tickers": sorted(_get_symbol_index())}


# ---------------------------------------------------------
# Get transcript for a symbol
# ---------------------------------------------------------
@app.get("/transcript/{symbol}")
def get_transcript(symbol: str):
    """Return all transcript rows for *symbol*; 404 when none exist."""
    ds = load_hf_dataset()
    symbol = symbol.upper()
    logger.info("Fetching transcript for: %s", symbol)

    indices = _get_symbol_index().get(symbol)
    if not indices:
        raise HTTPException(status_code=404, detail=f"No transcript found for {symbol}")

    safe_rows = [to_json_safe(ds[i]) for i in indices]
    return {"symbol": symbol, "records": safe_rows}


# ---------------------------------------------------------
# Dataset info (size + columns)
# ---------------------------------------------------------
@app.get("/dataset-info")
def dataset_info():
    """Return row count, column names, and the on-disk cache path."""
    ds = load_hf_dataset()
    return {
        "num_rows": len(ds),
        "columns": ds.column_names,
        "cache_path": CACHE_PATH,
    }


# ---------------------------------------------------------
# Dataset summary (high-level stats)
# ---------------------------------------------------------
@app.get("/dataset-summary")
def dataset_summary():
    """High-level stats: rows, symbols, year/quarter coverage, date range."""
    ds = load_hf_dataset()
    symbols = {s.upper() for s in ds["symbol"]}
    years = set(ds["year"])
    quarters = set(ds["quarter"])
    # Drop null dates before computing the range.
    dates = [d for d in ds["date"] if d is not None]
    return {
        "total_rows": len(ds),
        "unique_symbols": len(symbols),
        "symbols_sample": sorted(symbols)[:20],
        "year_range": {
            "min_year": min(years),
            "max_year": max(years),
        },
        "quarters_present": sorted(quarters),
        "date_range": {
            "min_date": min(dates) if dates else None,
            "max_date": max(dates) if dates else None,
        },
        "company_count": len(set(ds["company_id"])),
    }


@app.get("/check/{symbol}")
def check_symbol(symbol: str):
    """Report whether *symbol* appears anywhere in the dataset."""
    symbol = symbol.upper()
    if symbol not in _get_symbol_index():
        logger.warning("Symbol not found: %s", symbol)
        return {
            "symbol": symbol,
            "exists": False,
            "message": f"Symbol '{symbol}' does not exist in the dataset.",
        }
    logger.info("Symbol exists: %s", symbol)
    return {
        "symbol": symbol,
        "exists": True,
        "message": f"Symbol '{symbol}' exists in the dataset.",
    }