claims-env / server /plaid_client.py
akhiilll's picture
Deploy ClaimSense adjudication gym
1cfeb15 verified
"""Production-grade Plaid client for purchase verification.
This module is the *real* counterpart to ``plaid_mock.BankProbeStub`` —
it speaks to the genuine Plaid API and surfaces a ``LedgerHit`` shaped
identically to the mock. The gym swaps between them at construction
time when ``PLAID_CLIENT_ID`` / ``PLAID_SECRET`` are populated.
Setup
=====
1. ``pip install plaid-python``
2. Set environment variables before starting the Space::

       export PLAID_CLIENT_ID=...
       export PLAID_SECRET=...
       export PLAID_ENV=sandbox  # or development / production

3. Use the Plaid Link UI on the front end to obtain a public token,
   then exchange it once via :meth:`PlaidGateway.exchange_public_token`.
   Store the resulting ``access_token`` per claimant.
The only public method the gym calls is :meth:`PlaidGateway.verify_purchase`.
Everything else is Plaid plumbing kept here so the gym never has to
know about Plaid SDK types.
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from datetime import date, datetime, timedelta
from typing import Any
# Plaid SDK is optional at install time — degrade gracefully.
try:
import plaid
from plaid.api import plaid_api
from plaid.model.country_code import CountryCode
from plaid.model.item_public_token_exchange_request import (
ItemPublicTokenExchangeRequest,
)
from plaid.model.link_token_create_request import LinkTokenCreateRequest
from plaid.model.link_token_create_request_user import LinkTokenCreateRequestUser
from plaid.model.products import Products
from plaid.model.transactions_get_request import TransactionsGetRequest
from plaid.model.transactions_get_request_options import (
TransactionsGetRequestOptions,
)
from plaid.model.transactions_sync_request import TransactionsSyncRequest
PLAID_AVAILABLE = True
except ImportError: # pragma: no cover — dev path
plaid = None # type: ignore[assignment]
PLAID_AVAILABLE = False
# ---------------------------------------------------------------------------
# Result type — same shape as plaid_mock.LedgerHit
# ---------------------------------------------------------------------------
@dataclass()
class LedgerHit:
    """Result of a single ``verify_purchase`` lookup.

    Field-for-field mirror of the mock backend's result type, so the gym
    can consume either backend's output without caring which produced it.
    """

    # Whether a matching transaction was found.
    found: bool
    # Identifier of the matched transaction (empty on a miss).
    transaction_id: str
    amount: float
    date: str
    merchant: str
    category: str
    # Weighted match score of the selected transaction.
    confidence: float
    # Set when the result needs review; the reason is in discrepancy_reason.
    discrepancy: bool
    discrepancy_reason: str | None


# Older callers import the result type under its previous name.
TransactionMatch = LedgerHit
# ---------------------------------------------------------------------------
# Environment selection
# ---------------------------------------------------------------------------
def _resolve_environment(name: str) -> Any:
    """Translate a string label into a Plaid SDK ``Environment`` host.

    Args:
        name: Case-insensitive label: ``"sandbox"``, ``"development"`` or
            ``"production"``. Surrounding whitespace is ignored.

    Returns:
        The matching ``plaid.Environment`` attribute. Unrecognised labels fall
        back to ``plaid.Environment.Sandbox``. So does ``"development"`` on
        an SDK build that does not define that attribute.

    Raises:
        ImportError: if ``plaid-python`` is not installed.
    """
    if not PLAID_AVAILABLE:
        raise ImportError(
            "plaid-python is not installed. Run `pip install plaid-python`."
        )
    sandbox = plaid.Environment.Sandbox
    # Map label -> attribute *name* and resolve lazily with getattr. An eager
    # dict of attribute values would raise AttributeError for every label on
    # SDK releases that no longer ship ``Environment.Development``.
    attr_name = {
        "sandbox": "Sandbox",
        "development": "Development",
        "production": "Production",
    }.get(name.strip().lower())
    if attr_name is None:
        return sandbox
    return getattr(plaid.Environment, attr_name, sandbox)
# ---------------------------------------------------------------------------
# Gateway
# ---------------------------------------------------------------------------
class PlaidGateway:
    """Thin wrapper around ``plaid_api.PlaidApi`` tailored to claims work.

    Lifecycle::

        gateway = PlaidGateway()                    # reads creds from env vars
        link_token = gateway.create_link_token("user-42")
        # … browser-side Plaid Link returns a public_token …
        access_token = gateway.exchange_public_token(public_token)
        hit = gateway.verify_purchase(
            access_token=access_token,
            claimed_amount=3500.0,
            claimed_date="2024-03-01",
            claimed_description="Auto repair",
        )

    Matching: every candidate transaction gets the weighted score
    ``AMOUNT_WEIGHT * amount + DATE_WEIGHT * date + DESCRIPTION_WEIGHT *
    description``, where each component lies in [0, 1]. The best candidate
    is accepted only if its score reaches ``MIN_CONFIDENCE``.
    """

    # Relative amount difference above which an accepted match is flagged
    # as a discrepancy (0.15 == 15 %).
    DEFAULT_TOLERANCE = 0.15
    # Half-width, in days, of the search window around the claimed date.
    DEFAULT_DATE_WINDOW_DAYS = 30
    AMOUNT_WEIGHT = 0.5
    DATE_WEIGHT = 0.3
    DESCRIPTION_WEIGHT = 0.2
    # Minimum weighted score for a candidate to count as a match.
    MIN_CONFIDENCE = 0.5
    # Passed as ``client_name`` when minting Link tokens.
    PRODUCT_LINK_NAME = "ClaimSense"
    # Page size for /transactions/get; Plaid caps ``options.count`` at 500.
    PAGE_SIZE = 500

    def __init__(
        self,
        client_id: str | None = None,
        secret: str | None = None,
        environment: str | None = None,
    ) -> None:
        """Build an authenticated Plaid API client.

        Args:
            client_id: Plaid client id; defaults to ``$PLAID_CLIENT_ID``.
            secret: Plaid secret; defaults to ``$PLAID_SECRET``.
            environment: ``"sandbox"``, ``"development"`` or ``"production"``.
                An explicit value wins. Otherwise ``$PLAID_ENV`` is used,
                and ``"sandbox"`` when that is unset or empty as well.

        Raises:
            ImportError: if ``plaid-python`` is not installed.
            ValueError: if no client id or secret can be found.
        """
        if not PLAID_AVAILABLE:
            raise ImportError(
                "plaid-python is not installed. Run `pip install plaid-python`."
            )
        self.client_id = client_id or os.environ.get("PLAID_CLIENT_ID")
        self.secret = secret or os.environ.get("PLAID_SECRET")
        # Same precedence as the credentials: explicit argument, then the
        # environment variable, then the safe sandbox default.
        self.environment_name = (
            environment or os.environ.get("PLAID_ENV") or "sandbox"
        )
        if not self.client_id or not self.secret:
            raise ValueError(
                "Plaid credentials missing. Set PLAID_CLIENT_ID and "
                "PLAID_SECRET environment variables, or pass them to "
                "PlaidGateway()."
            )
        configuration = plaid.Configuration(
            host=_resolve_environment(self.environment_name),
            api_key={"clientId": self.client_id, "secret": self.secret},
        )
        self._client = plaid_api.PlaidApi(plaid.ApiClient(configuration))

    # ------------------------------------------------------------------
    # Plaid Link bootstrap
    # ------------------------------------------------------------------
    def create_link_token(self, user_id: str) -> str:
        """Mint a Link token used by the front-end Plaid Link widget.

        Args:
            user_id: Identifier for the end user, sent as ``client_user_id``.

        Returns:
            The ``link_token`` field of Plaid's response.
        """
        request = LinkTokenCreateRequest(
            user=LinkTokenCreateRequestUser(client_user_id=user_id),
            client_name=self.PRODUCT_LINK_NAME,
            products=[Products("transactions")],
            country_codes=[CountryCode("US")],
            language="en",
        )
        response = self._client.link_token_create(request)
        return response["link_token"]

    def exchange_public_token(self, public_token: str) -> str:
        """Trade a one-time public token for a long-lived access token."""
        request = ItemPublicTokenExchangeRequest(public_token=public_token)
        response = self._client.item_public_token_exchange(request)
        return response["access_token"]

    # ------------------------------------------------------------------
    # Transaction retrieval
    # ------------------------------------------------------------------
    def fetch_transactions(
        self,
        access_token: str,
        start_date: date,
        end_date: date,
    ) -> list[dict[str, Any]]:
        """Return *all* transactions in ``[start_date, end_date]``.

        Pages through ``/transactions/get`` ``PAGE_SIZE`` records at a time
        until the reported ``total_transactions`` have been collected. SDK
        exceptions (e.g. ``plaid.ApiException``) propagate to the caller.
        """
        transactions: list[dict[str, Any]] = []
        while True:
            page = self._client.transactions_get(
                TransactionsGetRequest(
                    access_token=access_token,
                    start_date=start_date,
                    end_date=end_date,
                    options=TransactionsGetRequestOptions(
                        count=self.PAGE_SIZE,
                        offset=len(transactions),
                    ),
                )
            )
            batch = list(page["transactions"])
            transactions.extend(batch)
            total = int(page["total_transactions"])
            # Also stop on an empty page: if the reported total drifts
            # between requests, ``len < total`` alone could spin forever.
            if not batch or len(transactions) >= total:
                return transactions

    def sync_transactions(
        self, access_token: str, cursor: str | None = None
    ) -> dict[str, Any]:
        """Incremental ``/transactions/sync`` wrapper.

        Recommended over ``fetch_transactions`` for production: Plaid
        returns only the deltas, paginated by ``next_cursor``. All pages are
        drained before this returns.

        Args:
            access_token: Plaid access token for the Item.
            cursor: Cursor from a previous sync. When omitted (``None`` or
                empty), no cursor is sent.

        Returns:
            A dict with keys ``added``, ``modified`` and ``removed``, each
            accumulated across every page, plus ``next_cursor`` taken from
            the last page.

        NOTE(review): Plaid documents a
        ``TRANSACTIONS_SYNC_MUTATION_DURING_PAGINATION`` error that requires
        restarting from the original cursor. This wrapper does not retry, so
        that error surfaces to the caller as an SDK exception.
        """
        first_request = (
            TransactionsSyncRequest(access_token=access_token, cursor=cursor)
            if cursor
            else TransactionsSyncRequest(access_token=access_token)
        )
        response = self._client.transactions_sync(first_request)
        added = list(response["added"])
        modified = list(response["modified"])
        removed = list(response["removed"])
        while response["has_more"]:
            response = self._client.transactions_sync(
                TransactionsSyncRequest(
                    access_token=access_token,
                    cursor=response["next_cursor"],
                )
            )
            added.extend(response["added"])
            modified.extend(response["modified"])
            removed.extend(response["removed"])
        return {
            "added": added,
            "modified": modified,
            "removed": removed,
            "next_cursor": response["next_cursor"],
        }

    # ------------------------------------------------------------------
    # The method the gym actually calls
    # ------------------------------------------------------------------
    def verify_purchase(
        self,
        access_token: str,
        claimed_amount: float,
        claimed_date: str,
        claimed_description: str = "",
        tolerance: float = DEFAULT_TOLERANCE,
        date_range_days: int = DEFAULT_DATE_WINDOW_DAYS,
    ) -> LedgerHit:
        """Look for the strongest transaction match in a ±N-day window.

        Args:
            access_token: Plaid access token for the claimant's Item.
            claimed_amount: Amount the claimant says they paid.
            claimed_date: Claimed purchase date, formatted ``YYYY-MM-DD``.
            claimed_description: Free text. Words longer than two characters
                are matched against the merchant name.
            tolerance: Relative amount difference above which an accepted
                match is flagged as a discrepancy.
            date_range_days: Half-width of the search window, in days.

        Returns:
            A ``LedgerHit``. An unparseable date, a Plaid API error, and the
            case where no candidate reaches ``MIN_CONFIDENCE`` all come back
            as a not-found hit with a reason, not as an exception.
        """
        try:
            window_centre = datetime.strptime(claimed_date, "%Y-%m-%d").date()
        except (ValueError, TypeError) as exc:
            return _miss(f"Could not parse claimed_date: {exc}")
        start = window_centre - timedelta(days=date_range_days)
        end = window_centre + timedelta(days=date_range_days)
        try:
            transactions = self.fetch_transactions(access_token, start, end)
        except plaid.ApiException as exc:  # type: ignore[attr-defined]
            return _miss(f"Plaid API error: {exc.body}")
        best_tx, best_score = self._best_match(
            transactions=transactions,
            claimed_amount=claimed_amount,
            claimed_description=claimed_description,
            window_centre=window_centre,
            window_days=date_range_days,
        )
        if best_tx is None or best_score < self.MIN_CONFIDENCE:
            return _miss("No matching transaction found in bank records")
        matched_amount = abs(float(best_tx["amount"]))
        diff_pct = abs(matched_amount - claimed_amount) / max(1.0, claimed_amount)
        flagged = diff_pct > tolerance
        return LedgerHit(
            found=True,
            transaction_id=str(best_tx["transaction_id"]),
            amount=matched_amount,
            date=str(best_tx["date"]),
            merchant=str(
                best_tx.get("merchant_name") or best_tx.get("name") or "Unknown"
            ),
            category=self._category_of(best_tx),
            confidence=best_score,
            discrepancy=flagged,
            discrepancy_reason=(
                f"Claimed ${claimed_amount:,.2f} but transaction shows "
                f"${matched_amount:,.2f}"
                if flagged
                else None
            ),
        )

    # ------------------------------------------------------------------
    # Internal scoring helpers
    # ------------------------------------------------------------------
    def _best_match(
        self,
        *,
        transactions: list[dict[str, Any]],
        claimed_amount: float,
        claimed_description: str,
        window_centre: date,
        window_days: int,
    ) -> tuple[dict[str, Any] | None, float]:
        """Return ``(best_transaction, score)``, or ``(None, 0.0)``.

        ``(None, 0.0)`` is returned when no candidate scores above zero. On a
        tie, the earliest transaction in *transactions* wins. Inflows
        (negative amounts) are skipped.
        """
        best_tx: dict[str, Any] | None = None
        best_score = 0.0
        keywords = [
            kw for kw in claimed_description.lower().split() if len(kw) > 2
        ]
        for tx in transactions:
            # Plaid reports money leaving the account as a positive amount
            # and money arriving (refunds, credits) as a negative one. An
            # inflow cannot evidence a purchase, so it is never a candidate.
            if float(tx["amount"]) < 0:
                continue
            score = self._score(
                tx=tx,
                claimed_amount=claimed_amount,
                keywords=keywords,
                window_centre=window_centre,
                window_days=window_days,
            )
            if score > best_score:
                best_score, best_tx = score, tx
        return best_tx, best_score

    def _score(
        self,
        *,
        tx: dict[str, Any],
        claimed_amount: float,
        keywords: list[str],
        window_centre: date,
        window_days: int,
    ) -> float:
        """Weighted similarity in [0, 1] between *tx* and the claim.

        - Amount: ``1 - relative difference``, floored at 0.
        - Date: ``1 - days_off / window_days``, floored at 0. It is 0 when the
          transaction date cannot be determined.
        - Description: 1.0 if any keyword occurs in the merchant name,
          otherwise 0.5 (neutral).
        """
        amount = abs(float(tx["amount"]))
        amount_diff = abs(amount - claimed_amount) / max(1.0, claimed_amount)
        amount_score = max(0.0, 1.0 - amount_diff)
        tx_date = self._parse_tx_date(tx.get("date"))
        if tx_date is None:
            # An unknown date must not count as a perfect date match.
            date_score = 0.0
        else:
            days_diff = abs((tx_date - window_centre).days)
            date_score = max(0.0, 1.0 - days_diff / max(1, window_days))
        merchant = (tx.get("merchant_name") or tx.get("name") or "").lower()
        if keywords and any(kw in merchant for kw in keywords):
            description_score = 1.0
        else:
            description_score = 0.5
        return (
            self.AMOUNT_WEIGHT * amount_score
            + self.DATE_WEIGHT * date_score
            + self.DESCRIPTION_WEIGHT * description_score
        )

    @staticmethod
    def _parse_tx_date(raw: Any) -> date | None:
        """Coerce a transaction date (``date``, ``datetime`` or ISO string).

        Returns ``None`` when the value cannot be interpreted.
        """
        # ``datetime`` subclasses ``date``, so it must be checked first.
        if isinstance(raw, datetime):
            return raw.date()
        if isinstance(raw, date):
            return raw
        try:
            return datetime.strptime(str(raw), "%Y-%m-%d").date()
        except (ValueError, TypeError):
            return None

    @staticmethod
    def _category_of(tx: Any) -> str:
        """Best-effort top-level category label for a transaction.

        Prefers the first entry of the legacy ``category`` list. Falls back
        to ``personal_finance_category.primary``, then to ``"unknown"``.
        """
        legacy = tx.get("category")
        if legacy:
            return str(legacy[0])
        pfc = tx.get("personal_finance_category")
        primary = pfc.get("primary") if pfc else None
        return str(primary) if primary else "unknown"
# ---------------------------------------------------------------------------
# Module-level helpers
# ---------------------------------------------------------------------------
def _miss(reason: str) -> LedgerHit:
    """Return a not-found ``LedgerHit`` carrying *reason*.

    Text fields are empty, numeric fields are zero, and ``discrepancy`` is
    set to True with *reason* as the explanation.
    """
    blanks = dict(transaction_id="", date="", merchant="", category="")
    return LedgerHit(
        found=False,
        amount=0.0,
        confidence=0.0,
        discrepancy=True,
        discrepancy_reason=reason,
        **blanks,
    )
def get_plaid_gateway() -> "PlaidGateway":
    """Factory returning a ``PlaidGateway`` configured from the environment.

    Propagates whatever the constructor raises: ``ImportError`` when the
    Plaid SDK is missing, ``ValueError`` when credentials are missing.
    """
    gateway = PlaidGateway()
    return gateway
def summarize_ledger_hit(hit: LedgerHit) -> str:
    """Render *hit* as a single human-readable log line.

    Uses the same format as the ``plaid_mock`` formatter, so logs look the
    same whichever backend produced the hit.
    """
    if hit.found:
        status = "DISCREPANCY DETECTED" if hit.discrepancy else "VERIFIED"
        pieces = [
            f"{status}: Transaction found - ${hit.amount:,.2f} at "
            f"{hit.merchant} on {hit.date}"
        ]
        if hit.discrepancy:
            pieces.append(f"WARNING: {hit.discrepancy_reason}")
        return " | ".join(pieces)
    return f"VERIFICATION FAILED: {hit.discrepancy_reason}"
# Backwards-compat aliases for callers written against the older names.
PlaidClient = PlaidGateway
get_plaid_client = get_plaid_gateway
format_verification_result = summarize_ledger_hit

__all__ = [
    "LedgerHit",
    "PlaidGateway",
    "summarize_ledger_hit",
    "get_plaid_gateway",
    # SDK-availability flag (set by the guarded import); not a legacy name.
    "PLAID_AVAILABLE",
    # legacy aliases
    "TransactionMatch",
    "PlaidClient",
    "get_plaid_client",
    "format_verification_result",
]