Spaces:
Sleeping
Sleeping
| import logging | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from datasets import load_dataset, load_from_disk | |
| import numpy as np | |
| import os | |
# ---------------------------------------------------------
# Application object (module-level side effect)
# ---------------------------------------------------------
app = FastAPI()

# ---------------------------------------------------------
# Logging
# ---------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------
# CORS
# ---------------------------------------------------------
# NOTE(review): wildcard origins/methods/headers — acceptable for a public
# read-only demo; confirm before serving anything sensitive.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ---------------------------------------------------------
# Dataset caching configuration
# ---------------------------------------------------------
DATASET_NAME = "kurry/sp500_earnings_transcripts"
CACHE_PATH = "/data/hf_dataset"  # persistent bucket mount
# In-process memo of the loaded dataset; populated lazily by load_hf_dataset().
dataset_cache = None
def load_hf_dataset():
    """
    Return the HF dataset, memoized both in-process and on disk.

    Resolution order:
    1. Already loaded this process  -> return the cached object.
    2. CACHE_PATH exists on disk    -> load from disk (fast, offline).
    3. Otherwise                    -> download once, persist, then return.
    """
    global dataset_cache

    if dataset_cache is None:
        if os.path.exists(CACHE_PATH):
            # Disk cache hit: no network needed.
            logger.info(f"Loading dataset from cache at {CACHE_PATH}")
            dataset_cache = load_from_disk(CACHE_PATH)
            logger.info(f"Loaded {len(dataset_cache)} rows from cached dataset")
        else:
            # First run: fetch from the Hub and persist for future restarts.
            logger.info(f"Downloading HF dataset: {DATASET_NAME}")
            dataset_cache = load_dataset(DATASET_NAME, split="train")
            logger.info(f"Saving dataset to cache at {CACHE_PATH}")
            dataset_cache.save_to_disk(CACHE_PATH)
            logger.info(f"Dataset cached and loaded ({len(dataset_cache)} rows)")
    return dataset_cache
| # --------------------------------------------------------- | |
| # JSON-safe conversion | |
| # --------------------------------------------------------- | |
def to_json_safe(obj):
    """
    Recursively convert numpy scalars, arrays, and containers of them into
    plain Python types that a JSON encoder accepts.

    Handles: np.bool_ -> bool, np.integer -> int, np.floating -> float,
    ndarray/list/tuple -> list (recursed), dict -> dict (values recursed).
    Anything else is returned unchanged.
    """
    # np.bool_ must be checked explicitly: it is neither np.integer nor
    # np.floating, so the original code let it fall through un-serialized.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    # Tuples are normalized to lists as well (JSON has no tuple type).
    if isinstance(obj, (np.ndarray, list, tuple)):
        return [to_json_safe(x) for x in obj]
    if isinstance(obj, dict):
        return {k: to_json_safe(v) for k, v in obj.items()}
    return obj
| # --------------------------------------------------------- | |
| # Serve index.html | |
| # --------------------------------------------------------- | |
def serve_index():
    """
    Return the contents of index.html (from the working directory) as a
    string, or a minimal inline error page if the file is missing.

    NOTE(review): likely registered as a FastAPI route returning
    HTMLResponse — the decorator appears to have been lost; confirm.
    """
    if not os.path.exists("index.html"):
        return "<h1>index.html not found</h1>"
    # Explicit encoding: the platform default (locale-dependent) can mangle
    # non-ASCII characters in the page.
    with open("index.html", "r", encoding="utf-8") as f:
        return f.read()
| # --------------------------------------------------------- | |
| # List all symbols | |
| # --------------------------------------------------------- | |
def get_tickers():
    """Return every distinct ticker symbol in the dataset, upper-cased and sorted."""
    ds = load_hf_dataset()
    unique_symbols = {sym.upper() for sym in ds["symbol"]}
    return {"tickers": sorted(unique_symbols)}
| # --------------------------------------------------------- | |
| # Get transcript for a symbol | |
| # --------------------------------------------------------- | |
def get_transcript(symbol: str):
    """
    Return all transcript records for one ticker symbol.

    Args:
        symbol: ticker symbol, case-insensitive.

    Returns:
        {"symbol": <upper-cased symbol>, "records": <list of JSON-safe rows>}

    Raises:
        HTTPException: 404 when the symbol has no rows in the dataset.
    """
    ds = load_hf_dataset()
    symbol = symbol.upper()
    logger.info(f"Fetching transcript for: {symbol}")
    # Scan only the "symbol" column to find matching row indices; iterating
    # full rows would decode every transcript body just to compare one field.
    indices = [i for i, s in enumerate(ds["symbol"]) if s.upper() == symbol]
    if not indices:
        raise HTTPException(status_code=404, detail=f"No transcript found for {symbol}")
    safe_rows = [to_json_safe(ds[i]) for i in indices]
    return {"symbol": symbol, "records": safe_rows}
| # --------------------------------------------------------- | |
| # Dataset info (size + columns) | |
| # --------------------------------------------------------- | |
def dataset_info():
    """Return basic dataset metadata: row count, column names, and cache location."""
    ds = load_hf_dataset()
    return {
        "num_rows": len(ds),
        "columns": ds.column_names,
        "cache_path": CACHE_PATH,
    }
| # --------------------------------------------------------- | |
| # Dataset summary (high-level stats) | |
| # --------------------------------------------------------- | |
def dataset_summary():
    """
    Return high-level statistics over the whole dataset: row/symbol counts,
    a sample of symbols, year and date ranges, quarters present, and the
    number of distinct companies.
    """
    ds = load_hf_dataset()

    symbols = {s.upper() for s in ds["symbol"]}
    years = set(ds["year"])
    quarters = set(ds["quarter"])
    # Some rows carry no date; ignore them when computing the range.
    known_dates = [d for d in ds["date"] if d is not None]

    return {
        "total_rows": len(ds),
        "unique_symbols": len(symbols),
        "symbols_sample": sorted(symbols)[:20],
        "year_range": {
            "min_year": min(years),
            "max_year": max(years)
        },
        "quarters_present": sorted(quarters),
        "date_range": {
            "min_date": min(known_dates) if known_dates else None,
            "max_date": max(known_dates) if known_dates else None
        },
        "company_count": len(set(ds["company_id"])),
    }
def check_symbol(symbol: str):
    """
    Report whether a ticker symbol exists in the dataset.

    Args:
        symbol: ticker symbol, case-insensitive.

    Returns:
        {"symbol": ..., "exists": bool, "message": ...} — never raises for
        a missing symbol (unlike get_transcript, which returns 404).
    """
    ds = load_hf_dataset()
    symbol = symbol.upper()
    # Scan only the "symbol" column; iterating full rows would decode every
    # transcript body just to compare one field.
    exists = any(s.upper() == symbol for s in ds["symbol"])
    if exists:
        logger.info(f"Symbol exists: {symbol}")
        return {
            "symbol": symbol,
            "exists": True,
            "message": f"Symbol '{symbol}' exists in the dataset."
        }
    logger.warning(f"Symbol not found: {symbol}")
    return {
        "symbol": symbol,
        "exists": False,
        "message": f"Symbol '{symbol}' does not exist in the dataset."
    }