"""Production-grade Plaid client for purchase verification. This module is the *real* counterpart to ``plaid_mock.BankProbeStub`` — it speaks to the genuine Plaid API and surfaces a ``LedgerHit`` shaped identically to the mock. The gym swaps between them at construction time when ``PLAID_CLIENT_ID`` / ``PLAID_SECRET`` are populated. Setup ===== 1. ``pip install plaid-python`` 2. Set environment variables before starting the Space:: export PLAID_CLIENT_ID=... export PLAID_SECRET=... export PLAID_ENV=sandbox # or development / production 3. Drive the Plaid Link UI on the front-end to obtain a public token, then exchange it once via :meth:`PlaidGateway.exchange_public_token`. Keep the resulting ``access_token`` per-claimant. The only public method the gym calls is :meth:`PlaidGateway.verify_purchase`. Everything else is Plaid plumbing kept here so the gym never has to know about Plaid SDK types. """ from __future__ import annotations import os from dataclasses import dataclass from datetime import date, datetime, timedelta from typing import Any # Plaid SDK is optional at install time — degrade gracefully. 
try:
    import plaid
    from plaid.api import plaid_api
    from plaid.model.country_code import CountryCode
    from plaid.model.item_public_token_exchange_request import (
        ItemPublicTokenExchangeRequest,
    )
    from plaid.model.link_token_create_request import LinkTokenCreateRequest
    from plaid.model.link_token_create_request_user import LinkTokenCreateRequestUser
    from plaid.model.products import Products
    from plaid.model.transactions_get_request import TransactionsGetRequest
    from plaid.model.transactions_get_request_options import (
        TransactionsGetRequestOptions,
    )
    from plaid.model.transactions_sync_request import TransactionsSyncRequest

    PLAID_AVAILABLE = True
except ImportError:  # pragma: no cover — dev path
    plaid = None  # type: ignore[assignment]
    PLAID_AVAILABLE = False


# ---------------------------------------------------------------------------
# Result type — mirrors plaid_mock.LedgerHit
# ---------------------------------------------------------------------------


@dataclass
class LedgerHit:
    """Outcome of one ``verify_purchase`` call."""

    found: bool  # True when a transaction scored at least MIN_CONFIDENCE
    transaction_id: str  # Plaid transaction id ("" on a miss)
    amount: float  # absolute value of the matched amount (0.0 on a miss)
    date: str  # transaction date rendered with str() ("" on a miss)
    merchant: str  # merchant_name, else name, else "Unknown" ("" on a miss)
    category: str  # first Plaid category, else "unknown" ("" on a miss)
    confidence: float  # weighted match score in [0, 1] (0.0 on a miss)
    discrepancy: bool  # amount outside tolerance, or no match at all
    discrepancy_reason: str | None  # human-readable explanation, or None


# Backwards-compat alias.
TransactionMatch = LedgerHit


# ---------------------------------------------------------------------------
# Environment selection
# ---------------------------------------------------------------------------

# Maps the accepted ``PLAID_ENV`` labels to attribute names on
# ``plaid.Environment``. Attributes are looked up lazily so an SDK release
# that no longer ships one of them only fails when that label is requested.
_ENVIRONMENT_ATTRS = {
    "sandbox": "Sandbox",
    "development": "Development",
    "production": "Production",
}


def _resolve_environment(name: str) -> Any:
    """Translate a string label into a Plaid SDK ``Environment`` value.

    The label is matched case-insensitively after stripping whitespace.
    Unrecognised labels fall back to Sandbox, as before.

    Raises:
        ImportError: ``plaid-python`` is not installed.
        ValueError: the label is recognised but the installed SDK does not
            expose that environment.
    """
    if not PLAID_AVAILABLE:
        raise ImportError(
            "plaid-python is not installed. Run `pip install plaid-python`."
        )
    attr = _ENVIRONMENT_ATTRS.get(name.strip().lower(), "Sandbox")
    try:
        return getattr(plaid.Environment, attr)
    except AttributeError:
        raise ValueError(
            f"Plaid environment {name!r} is not supported by the installed "
            "plaid-python SDK."
        ) from None


# ---------------------------------------------------------------------------
# Gateway
# ---------------------------------------------------------------------------


class PlaidGateway:
    """Thin wrapper around ``plaid_api.PlaidApi`` tailored to claims work.

    Lifecycle::

        gateway = PlaidGateway()  # reads creds from env vars
        link_token = gateway.create_link_token("user-42")
        # … browser-side Plaid Link returns a public_token …
        access_token = gateway.exchange_public_token(public_token)
        hit = gateway.verify_purchase(
            access_token=access_token,
            claimed_amount=3500.0,
            claimed_date="2024-03-01",
            claimed_description="Auto repair",
        )
    """

    DEFAULT_TOLERANCE = 0.15
    DEFAULT_DATE_WINDOW_DAYS = 30
    # Scoring weights; they sum to 1.0 so a score stays within [0, 1].
    AMOUNT_WEIGHT = 0.5
    DATE_WEIGHT = 0.3
    DESCRIPTION_WEIGHT = 0.2
    MIN_CONFIDENCE = 0.5
    PRODUCT_LINK_NAME = "ClaimSense"
    # Page size for /transactions/get; 500 is the maximum Plaid accepts.
    TRANSACTIONS_PAGE_SIZE = 500

    def __init__(
        self,
        client_id: str | None = None,
        secret: str | None = None,
        environment: str | None = None,
    ) -> None:
        """Build a gateway from explicit arguments or environment variables.

        Each setting prefers the explicit argument, then its environment
        variable (``PLAID_CLIENT_ID``, ``PLAID_SECRET``, ``PLAID_ENV``). The
        environment label finally defaults to ``"sandbox"``.

        Raises:
            ImportError: ``plaid-python`` is not installed.
            ValueError: client id or secret is missing, or the requested
                environment is unavailable in the installed SDK.
        """
        if not PLAID_AVAILABLE:
            raise ImportError(
                "plaid-python is not installed. Run `pip install plaid-python`."
            )

        self.client_id = client_id or os.environ.get("PLAID_CLIENT_ID")
        self.secret = secret or os.environ.get("PLAID_SECRET")
        # Same precedence as the credentials: an explicit argument wins
        # over the environment variable.
        self.environment_name = (
            environment or os.environ.get("PLAID_ENV") or "sandbox"
        )

        if not self.client_id or not self.secret:
            raise ValueError(
                "Plaid credentials missing. Set PLAID_CLIENT_ID and "
                "PLAID_SECRET environment variables, or pass them to "
                "PlaidGateway()."
            )

        configuration = plaid.Configuration(
            host=_resolve_environment(self.environment_name),
            api_key={"clientId": self.client_id, "secret": self.secret},
        )
        self._client = plaid_api.PlaidApi(plaid.ApiClient(configuration))

    # ------------------------------------------------------------------
    # Plaid Link bootstrap
    # ------------------------------------------------------------------

    def create_link_token(self, user_id: str) -> str:
        """Mint a Link token used by the front-end Plaid Link widget."""
        request = LinkTokenCreateRequest(
            user=LinkTokenCreateRequestUser(client_user_id=user_id),
            client_name=self.PRODUCT_LINK_NAME,
            products=[Products("transactions")],
            country_codes=[CountryCode("US")],
            language="en",
        )
        response = self._client.link_token_create(request)
        return response["link_token"]

    def exchange_public_token(self, public_token: str) -> str:
        """Trade a one-time public token for a long-lived access token."""
        request = ItemPublicTokenExchangeRequest(public_token=public_token)
        response = self._client.item_public_token_exchange(request)
        return response["access_token"]

    # ------------------------------------------------------------------
    # Transaction retrieval
    # ------------------------------------------------------------------

    def fetch_transactions(
        self,
        access_token: str,
        start_date: date,
        end_date: date,
    ) -> list[dict[str, Any]]:
        """Return *all* transactions in [start_date, end_date], paginating.

        Pages are requested with ``offset`` equal to the number already
        collected until ``total_transactions`` is reached. Collection also
        stops if Plaid returns an empty page, so an inconsistent total
        cannot cause an endless loop.
        """
        transactions: list[dict[str, Any]] = []
        while True:
            page = self._client.transactions_get(
                TransactionsGetRequest(
                    access_token=access_token,
                    start_date=start_date,
                    end_date=end_date,
                    options=TransactionsGetRequestOptions(
                        count=self.TRANSACTIONS_PAGE_SIZE,
                        offset=len(transactions),
                    ),
                )
            )
            batch = list(page["transactions"])
            transactions.extend(batch)
            # An empty page means no further progress is possible.
            if not batch or len(transactions) >= int(page["total_transactions"]):
                return transactions

    def sync_transactions(
        self, access_token: str, cursor: str | None = None
    ) -> dict[str, Any]:
        """Incremental ``/transactions/sync`` wrapper.

        Recommended over ``fetch_transactions`` for production — Plaid
        returns only the deltas, paginated by ``next_cursor``. All pages are
        drained while ``has_more`` is set.

        Returns:
            ``{"added", "modified", "removed", "next_cursor"}``. The first
            three keys hold accumulated lists. ``next_cursor`` is the cursor
            to pass on the next call.

        Any ``plaid.ApiException`` propagates to the caller. That includes
        Plaid's mutation-during-pagination error, after which Plaid advises
        restarting from the original ``cursor``.
        """
        first_request = (
            TransactionsSyncRequest(access_token=access_token, cursor=cursor)
            if cursor
            else TransactionsSyncRequest(access_token=access_token)
        )
        response = self._client.transactions_sync(first_request)
        added = list(response["added"])
        modified = list(response["modified"])
        removed = list(response["removed"])

        while response["has_more"]:
            response = self._client.transactions_sync(
                TransactionsSyncRequest(
                    access_token=access_token,
                    cursor=response["next_cursor"],
                )
            )
            added.extend(response["added"])
            modified.extend(response["modified"])
            removed.extend(response["removed"])

        return {
            "added": added,
            "modified": modified,
            "removed": removed,
            "next_cursor": response["next_cursor"],
        }

    # ------------------------------------------------------------------
    # The method the gym actually calls
    # ------------------------------------------------------------------

    def verify_purchase(
        self,
        access_token: str,
        claimed_amount: float,
        claimed_date: str,
        claimed_description: str = "",
        tolerance: float = DEFAULT_TOLERANCE,
        date_range_days: int = DEFAULT_DATE_WINDOW_DAYS,
    ) -> LedgerHit:
        """Look for the strongest transaction match in a ±N-day window.

        Args:
            access_token: Plaid access token for the claimant's Item.
            claimed_amount: Amount on the claim.
            claimed_date: Claim date in ``YYYY-MM-DD`` form.
            claimed_description: Free text; words longer than two characters
                are matched against the merchant name.
            tolerance: Relative amount difference above which the match is
                flagged as a discrepancy.
            date_range_days: Half-width of the search window in days.

        Returns:
            A ``LedgerHit``. Parse failures, Plaid API errors and
            low-confidence results yield ``found=False``; they are not raised.
        """
        try:
            window_centre = datetime.strptime(claimed_date, "%Y-%m-%d").date()
        except ValueError as exc:
            return _miss(f"Could not parse claimed_date: {exc}")

        start = window_centre - timedelta(days=date_range_days)
        end = window_centre + timedelta(days=date_range_days)

        try:
            transactions = self.fetch_transactions(access_token, start, end)
        except plaid.ApiException as exc:  # type: ignore[attr-defined]
            return _miss(f"Plaid API error: {exc.body}")

        best_tx, best_score = self._best_match(
            transactions=transactions,
            claimed_amount=claimed_amount,
            claimed_description=claimed_description,
            window_centre=window_centre,
            window_days=date_range_days,
        )

        if best_tx is None or best_score < self.MIN_CONFIDENCE:
            return _miss("No matching transaction found in bank records")

        # NOTE(review): abs() lets an inflow (negative Plaid amount, e.g. a
        # refund) match a claimed purchase — confirm this is intended.
        matched_amount = abs(float(best_tx["amount"]))
        diff_pct = abs(matched_amount - claimed_amount) / max(1.0, claimed_amount)
        flagged = diff_pct > tolerance

        return LedgerHit(
            found=True,
            transaction_id=str(best_tx["transaction_id"]),
            amount=matched_amount,
            date=str(best_tx["date"]),
            merchant=str(
                best_tx.get("merchant_name") or best_tx.get("name") or "Unknown"
            ),
            category=(
                best_tx["category"][0] if best_tx.get("category") else "unknown"
            ),
            confidence=best_score,
            discrepancy=flagged,
            discrepancy_reason=(
                f"Claimed ${claimed_amount:,.2f} but transaction shows "
                f"${matched_amount:,.2f}"
                if flagged
                else None
            ),
        )

    # ------------------------------------------------------------------
    # Internal scoring helpers
    # ------------------------------------------------------------------

    def _best_match(
        self,
        *,
        transactions: list[dict[str, Any]],
        claimed_amount: float,
        claimed_description: str,
        window_centre: date,
        window_days: int,
    ) -> tuple[dict[str, Any] | None, float]:
        """Return the highest-scoring transaction and its score.

        Ties keep the earliest transaction. The result is ``(None, 0.0)``
        when the list is empty or every score is 0.
        """
        best_tx: dict[str, Any] | None = None
        best_score = 0.0
        keywords = [
            kw for kw in claimed_description.lower().split() if len(kw) > 2
        ]
        for tx in transactions:
            score = self._score(
                tx=tx,
                claimed_amount=claimed_amount,
                keywords=keywords,
                window_centre=window_centre,
                window_days=window_days,
            )
            if score > best_score:
                best_score, best_tx = score, tx
        return best_tx, best_score

    @staticmethod
    def _coerce_tx_date(raw: Any) -> date | None:
        """Convert a transaction ``date`` field to ``date``, or None if unusable.

        Accepts ``datetime`` (checked first, since it subclasses ``date``),
        ``date``, or anything whose ``str()`` is ``YYYY-MM-DD``.
        """
        if isinstance(raw, datetime):
            return raw.date()
        if isinstance(raw, date):
            return raw
        try:
            return datetime.strptime(str(raw), "%Y-%m-%d").date()
        except (ValueError, TypeError):
            return None

    def _score(
        self,
        *,
        tx: dict[str, Any],
        claimed_amount: float,
        keywords: list[str],
        window_centre: date,
        window_days: int,
    ) -> float:
        """Weighted similarity in [0, 1] between a transaction and the claim.

        Components:
            amount: 1 minus the relative difference, floored at 0.
            date: falls off linearly to 0 at ``window_days`` away; a missing
                or unparseable date scores 0.
            description: 1.0 on any keyword hit in the merchant name,
                otherwise 0.5.
        """
        amount = abs(float(tx["amount"]))
        amount_diff = abs(amount - claimed_amount) / max(1.0, claimed_amount)
        amount_score = max(0.0, 1.0 - amount_diff)

        tx_date = self._coerce_tx_date(tx.get("date"))
        if tx_date is None:
            # An unknown date must not earn proximity credit.
            date_score = 0.0
        else:
            days_diff = abs((tx_date - window_centre).days)
            date_score = max(0.0, 1.0 - days_diff / max(1, window_days))

        merchant = (tx.get("merchant_name") or tx.get("name") or "").lower()
        if keywords:
            description_score = (
                1.0 if any(kw in merchant for kw in keywords) else 0.5
            )
        else:
            description_score = 0.5

        return (
            self.AMOUNT_WEIGHT * amount_score
            + self.DATE_WEIGHT * date_score
            + self.DESCRIPTION_WEIGHT * description_score
        )


# ---------------------------------------------------------------------------
# Module-level helpers
# ---------------------------------------------------------------------------


def _miss(reason: str) -> LedgerHit:
    """Build a "no match" result with the given explanation."""
    return LedgerHit(
        found=False,
        transaction_id="",
        amount=0.0,
        date="",
        merchant="",
        category="",
        confidence=0.0,
        discrepancy=True,
        discrepancy_reason=reason,
    )


def get_plaid_gateway() -> "PlaidGateway":
    """Build a configured ``PlaidGateway``; raises if Plaid is unavailable."""
    return PlaidGateway()


def summarize_ledger_hit(hit: LedgerHit) -> str:
    """Formatter shared with ``plaid_mock`` for consistent log output."""
    if not hit.found:
        return f"VERIFICATION FAILED: {hit.discrepancy_reason}"

    headline = "DISCREPANCY DETECTED" if hit.discrepancy else "VERIFIED"
    line = (
        f"{headline}: Transaction found - ${hit.amount:,.2f} at "
        f"{hit.merchant} on {hit.date}"
    )
    if hit.discrepancy:
        line += f" | WARNING: {hit.discrepancy_reason}"
    return line


# Backwards-compat aliases.
PlaidClient = PlaidGateway
get_plaid_client = get_plaid_gateway
format_verification_result = summarize_ledger_hit


__all__ = [
    "LedgerHit",
    "PlaidGateway",
    "summarize_ledger_hit",
    "get_plaid_gateway",
    # legacy
    "TransactionMatch",
    "PlaidClient",
    "get_plaid_client",
    "format_verification_result",
    "PLAID_AVAILABLE",
]