# transcripts-api / app.py
# Last commit: "Validates ticker" (3d04e63)
import logging
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from datasets import load_dataset, load_from_disk
import numpy as np
import os
app = FastAPI()
# ---------------------------------------------------------
# Logging
# ---------------------------------------------------------
# Root-logger config: INFO level, timestamped and levelled messages.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------
# CORS
# ---------------------------------------------------------
# Wide-open CORS: any origin, method, and header may call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------------------------------
# Dataset caching configuration
# ---------------------------------------------------------
# Hugging Face dataset id downloaded on first use.
DATASET_NAME = "kurry/sp500_earnings_transcripts"
CACHE_PATH = "/data/hf_dataset" # persistent bucket mount
# In-process memoization slot, filled lazily by load_hf_dataset().
dataset_cache = None
def load_hf_dataset():
    """Return the transcripts dataset, memoized in-process and on disk.

    Resolution order:
      1. in-memory module cache (``dataset_cache``)
      2. on-disk cache at ``CACHE_PATH`` (fast, works offline)
      3. fresh download from the Hub, then persisted to ``CACHE_PATH``
    """
    global dataset_cache
    if dataset_cache is None:
        if os.path.exists(CACHE_PATH):
            # Disk cache hit — no network needed.
            logger.info(f"Loading dataset from cache at {CACHE_PATH}")
            dataset_cache = load_from_disk(CACHE_PATH)
            logger.info(f"Loaded {len(dataset_cache)} rows from cached dataset")
        else:
            # First run: download, then persist for future startups.
            logger.info(f"Downloading HF dataset: {DATASET_NAME}")
            fresh = load_dataset(DATASET_NAME, split="train")
            logger.info(f"Saving dataset to cache at {CACHE_PATH}")
            fresh.save_to_disk(CACHE_PATH)
            dataset_cache = fresh
            logger.info(f"Dataset cached and loaded ({len(fresh)} rows)")
    return dataset_cache
# ---------------------------------------------------------
# JSON-safe conversion
# ---------------------------------------------------------
def to_json_safe(obj):
    """Recursively convert *obj* into JSON-serializable Python types.

    Handles NumPy scalars (bools, ints, floats), NumPy arrays, lists,
    tuples, and dicts. Anything else is returned unchanged.
    """
    # np.bool_ is not JSON-serializable either; check it first so it is
    # not missed (it is not an np.integer/np.floating subclass).
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    # Tuples added: json.dumps handles them, but nested NumPy values
    # inside a tuple would otherwise escape conversion.
    if isinstance(obj, (np.ndarray, list, tuple)):
        return [to_json_safe(x) for x in obj]
    if isinstance(obj, dict):
        return {k: to_json_safe(v) for k, v in obj.items()}
    return obj
# ---------------------------------------------------------
# Serve index.html
# ---------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
def serve_index():
    """Serve the static index.html from the working directory.

    Returns a small fallback HTML snippet when the file is missing
    rather than raising, so the root route never 500s.
    """
    if not os.path.exists("index.html"):
        return "<h1>index.html not found</h1>"
    # Explicit encoding: the platform default may not be UTF-8 (PEP 597),
    # which would corrupt any non-ASCII characters in the page.
    with open("index.html", "r", encoding="utf-8") as f:
        return f.read()
# ---------------------------------------------------------
# List all symbols
# ---------------------------------------------------------
@app.get("/tickers")
def get_tickers():
    """List every distinct ticker symbol, upper-cased and sorted."""
    ds = load_hf_dataset()
    unique = {sym.upper() for sym in ds["symbol"]}
    return {"tickers": sorted(unique)}
# ---------------------------------------------------------
# Get transcript for a symbol
# ---------------------------------------------------------
@app.get("/transcript/{symbol}")
def get_transcript(symbol: str):
    """Return every transcript record for *symbol*; 404 when none match."""
    ds = load_hf_dataset()
    symbol = symbol.upper()
    logger.info(f"Fetching transcript for: {symbol}")
    # Full scan: symbols are matched case-insensitively against the request.
    matches = [record for record in ds if record["symbol"].upper() == symbol]
    if not matches:
        raise HTTPException(status_code=404, detail=f"No transcript found for {symbol}")
    return {"symbol": symbol, "records": [to_json_safe(record) for record in matches]}
# ---------------------------------------------------------
# Dataset info (size + columns)
# ---------------------------------------------------------
@app.get("/dataset-info")
def dataset_info():
    """Expose basic dataset metadata: row count, columns, cache location."""
    ds = load_hf_dataset()
    return {
        "num_rows": len(ds),
        "columns": ds.column_names,
        "cache_path": CACHE_PATH,
    }
# ---------------------------------------------------------
# Dataset summary (high-level stats)
# ---------------------------------------------------------
@app.get("/dataset-summary")
def dataset_summary():
    """Return high-level statistics about the dataset.

    Covers row count, distinct symbols (with a sorted 20-symbol sample),
    year and date ranges, quarters present, and distinct company count.
    """
    ds = load_hf_dataset()
    symbols = {s.upper() for s in ds["symbol"]}
    # Drop missing values before aggregating: min()/max()/sorted() raise a
    # TypeError on None entries (dates were already filtered this way, but
    # years and quarters were not), and min()/max() raise ValueError on an
    # empty dataset. Mirror the dates handling everywhere.
    years = {y for y in ds["year"] if y is not None}
    quarters = {q for q in ds["quarter"] if q is not None}
    dates = [d for d in ds["date"] if d is not None]
    return {
        "total_rows": len(ds),
        "unique_symbols": len(symbols),
        "symbols_sample": sorted(symbols)[:20],
        "year_range": {
            "min_year": min(years) if years else None,
            "max_year": max(years) if years else None,
        },
        "quarters_present": sorted(quarters),
        "date_range": {
            "min_date": min(dates) if dates else None,
            "max_date": max(dates) if dates else None,
        },
        "company_count": len(set(ds["company_id"])),
    }
@app.get("/check/{symbol}")
def check_symbol(symbol: str):
    """Report whether *symbol* appears anywhere in the dataset."""
    ds = load_hf_dataset()
    symbol = symbol.upper()
    # Case-insensitive membership scan; any() stops at the first hit.
    found = any(row["symbol"].upper() == symbol for row in ds)
    if found:
        logger.info(f"Symbol exists: {symbol}")
        message = f"Symbol '{symbol}' exists in the dataset."
    else:
        logger.warning(f"Symbol not found: {symbol}")
        message = f"Symbol '{symbol}' does not exist in the dataset."
    return {
        "symbol": symbol,
        "exists": found,
        "message": message
    }