Spaces:
Sleeping
Sleeping
"""Production-grade Plaid client for purchase verification.

This module is the *real* counterpart to ``plaid_mock.BankProbeStub`` —
it speaks to the genuine Plaid API and surfaces a ``LedgerHit`` shaped
identically to the mock. The gym swaps between them at construction
time when ``PLAID_CLIENT_ID`` / ``PLAID_SECRET`` are populated.

Setup
=====

1. ``pip install plaid-python``
2. Set environment variables before starting the Space::

       export PLAID_CLIENT_ID=...
       export PLAID_SECRET=...
       export PLAID_ENV=sandbox  # or development / production

3. Drive the Plaid Link UI on the front-end to obtain a public token,
   then exchange it once via :meth:`PlaidGateway.exchange_public_token`.
   Store the resulting ``access_token`` per claimant.

The only public method the gym calls is :meth:`PlaidGateway.verify_purchase`.
Everything else is Plaid plumbing kept here so the gym never has to
know about Plaid SDK types.
"""
| from __future__ import annotations | |
| import os | |
| from dataclasses import dataclass | |
| from datetime import date, datetime, timedelta | |
| from typing import Any | |
| # Plaid SDK is optional at install time — degrade gracefully. | |
| try: | |
| import plaid | |
| from plaid.api import plaid_api | |
| from plaid.model.country_code import CountryCode | |
| from plaid.model.item_public_token_exchange_request import ( | |
| ItemPublicTokenExchangeRequest, | |
| ) | |
| from plaid.model.link_token_create_request import LinkTokenCreateRequest | |
| from plaid.model.link_token_create_request_user import LinkTokenCreateRequestUser | |
| from plaid.model.products import Products | |
| from plaid.model.transactions_get_request import TransactionsGetRequest | |
| from plaid.model.transactions_get_request_options import ( | |
| TransactionsGetRequestOptions, | |
| ) | |
| from plaid.model.transactions_sync_request import TransactionsSyncRequest | |
| PLAID_AVAILABLE = True | |
| except ImportError: # pragma: no cover — dev path | |
| plaid = None # type: ignore[assignment] | |
| PLAID_AVAILABLE = False | |
# ---------------------------------------------------------------------------
# Result type — mirrors plaid_mock.LedgerHit
# ---------------------------------------------------------------------------
@dataclass
class LedgerHit:
    """Outcome of one ``verify_purchase`` call.

    Shaped identically to ``plaid_mock.LedgerHit`` so the real gateway and
    the stub are interchangeable. The ``@dataclass`` decorator supplies the
    keyword constructor that ``verify_purchase`` and ``_miss`` rely on.

    Attributes:
        found: Whether a transaction scored at or above the match threshold.
        transaction_id: Plaid ``transaction_id`` of the match ("" on a miss).
        amount: Absolute amount of the matched transaction (0.0 on a miss).
        date: Transaction date rendered as a string ("" on a miss).
        merchant: Merchant name, falling back to the raw transaction name.
        category: First entry of the transaction's ``category`` list.
        confidence: Weighted match score in [0, 1] (0.0 on a miss).
        discrepancy: True when the claim disagrees with the bank record,
            and always True on a miss.
        discrepancy_reason: Human-readable explanation, or ``None``.
    """

    found: bool
    transaction_id: str
    amount: float
    date: str
    merchant: str
    category: str
    confidence: float
    discrepancy: bool
    discrepancy_reason: str | None


# Backwards-compat alias.
TransactionMatch = LedgerHit
# ---------------------------------------------------------------------------
# Environment selection
# ---------------------------------------------------------------------------
def _resolve_environment(name: str) -> Any:
    """Translate a string label into a Plaid SDK ``Environment`` value.

    Args:
        name: Environment label such as ``"sandbox"`` or ``"production"``.
            Matching ignores case and surrounding whitespace.

    Returns:
        The matching ``plaid.Environment`` attribute. Unrecognised labels
        fall back to the sandbox environment.

    Raises:
        ImportError: If ``plaid-python`` is not installed.
        ValueError: If the label names an environment that the installed
            SDK does not define.
    """
    if not PLAID_AVAILABLE:
        raise ImportError(
            "plaid-python is not installed. Run `pip install plaid-python`."
        )
    attribute_names = {
        "sandbox": "Sandbox",
        "development": "Development",
        "production": "Production",
    }
    # Resolve lazily. Building a dict of plaid.Environment.* values up front
    # raises AttributeError for *every* label (sandbox included) on
    # plaid-python releases that no longer ship Environment.Development.
    attribute = attribute_names.get(name.strip().lower(), "Sandbox")
    host = getattr(plaid.Environment, attribute, None)
    if host is None:
        raise ValueError(
            f"Plaid environment {name!r} is not supported by the installed "
            "plaid-python SDK."
        )
    return host
# ---------------------------------------------------------------------------
# Gateway
# ---------------------------------------------------------------------------
class PlaidGateway:
    """Thin wrapper around ``plaid_api.PlaidApi`` tailored to claims work.

    Lifecycle::

        gateway = PlaidGateway()  # reads creds from env vars
        link_token = gateway.create_link_token("user-42")
        # ... browser-side Plaid Link returns a public_token ...
        access_token = gateway.exchange_public_token(public_token)
        hit = gateway.verify_purchase(
            access_token=access_token,
            claimed_amount=3500.0,
            claimed_date="2024-03-01",
            claimed_description="Auto repair",
        )
    """

    # Relative amount difference above which a matched transaction is
    # reported as a discrepancy.
    DEFAULT_TOLERANCE = 0.15
    # Half-width, in days, of the search window around the claimed date.
    DEFAULT_DATE_WINDOW_DAYS = 30
    # Match-score weights; they sum to 1.0 so scores stay within [0, 1].
    AMOUNT_WEIGHT = 0.5
    DATE_WEIGHT = 0.3
    DESCRIPTION_WEIGHT = 0.2
    # The best candidate must score at least this much to count as a match.
    MIN_CONFIDENCE = 0.5
    # Sent as ``client_name`` when minting Link tokens.
    PRODUCT_LINK_NAME = "ClaimSense"
    # Page size for /transactions/get (Plaid caps ``count`` at 500).
    TRANSACTIONS_PAGE_SIZE = 500

    def __init__(
        self,
        client_id: str | None = None,
        secret: str | None = None,
        environment: str | None = None,
    ) -> None:
        """Create an authenticated Plaid API client.

        Each setting resolves in this order: explicit argument, then
        environment variable (``PLAID_CLIENT_ID``, ``PLAID_SECRET``,
        ``PLAID_ENV``). For the environment only, the final fallback is
        ``"sandbox"``.

        Raises:
            ImportError: If ``plaid-python`` is not installed.
            ValueError: If the client id or secret is missing.
        """
        if not PLAID_AVAILABLE:
            raise ImportError(
                "plaid-python is not installed. Run `pip install plaid-python`."
            )
        self.client_id = client_id or os.environ.get("PLAID_CLIENT_ID")
        self.secret = secret or os.environ.get("PLAID_SECRET")
        # An explicit argument wins over PLAID_ENV, exactly like the
        # credentials above. Otherwise PlaidGateway(environment="sandbox")
        # could silently talk to whatever PLAID_ENV names.
        self.environment_name = (
            environment or os.environ.get("PLAID_ENV") or "sandbox"
        )
        if not self.client_id or not self.secret:
            raise ValueError(
                "Plaid credentials missing. Set PLAID_CLIENT_ID and "
                "PLAID_SECRET environment variables, or pass them to "
                "PlaidGateway()."
            )
        configuration = plaid.Configuration(
            host=_resolve_environment(self.environment_name),
            api_key={"clientId": self.client_id, "secret": self.secret},
        )
        self._client = plaid_api.PlaidApi(plaid.ApiClient(configuration))

    # ------------------------------------------------------------------
    # Plaid Link bootstrap
    # ------------------------------------------------------------------
    def create_link_token(self, user_id: str) -> str:
        """Mint a Link token used by the front-end Plaid Link widget.

        Args:
            user_id: Caller-side identifier for the end user, sent to Plaid
                as ``client_user_id``.

        Returns:
            The ``link_token`` string from Plaid's response.
        """
        request = LinkTokenCreateRequest(
            user=LinkTokenCreateRequestUser(client_user_id=user_id),
            client_name=self.PRODUCT_LINK_NAME,
            products=[Products("transactions")],
            country_codes=[CountryCode("US")],
            language="en",
        )
        response = self._client.link_token_create(request)
        return response["link_token"]

    def exchange_public_token(self, public_token: str) -> str:
        """Trade a one-time public token for a long-lived access token."""
        request = ItemPublicTokenExchangeRequest(public_token=public_token)
        response = self._client.item_public_token_exchange(request)
        return response["access_token"]

    # ------------------------------------------------------------------
    # Transaction retrieval
    # ------------------------------------------------------------------
    def fetch_transactions(
        self,
        access_token: str,
        start_date: date,
        end_date: date,
    ) -> list[dict[str, Any]]:
        """Return *all* transactions in [start_date, end_date], paginating.

        Pages through ``/transactions/get`` by offset until the reported
        ``total_transactions`` has been collected.
        """
        transactions: list[dict[str, Any]] = []
        while True:
            page = self._client.transactions_get(
                TransactionsGetRequest(
                    access_token=access_token,
                    start_date=start_date,
                    end_date=end_date,
                    options=TransactionsGetRequestOptions(
                        count=self.TRANSACTIONS_PAGE_SIZE,
                        offset=len(transactions),
                    ),
                )
            )
            batch = list(page["transactions"])
            transactions.extend(batch)
            total = int(page["total_transactions"])
            # Stop on an empty page as well as on reaching the total. If
            # transactions disappear between requests, the total can stay
            # out of reach, and the loop must still terminate.
            if not batch or len(transactions) >= total:
                return transactions

    def sync_transactions(
        self, access_token: str, cursor: str | None = None
    ) -> dict[str, Any]:
        """Incremental ``/transactions/sync`` wrapper.

        Recommended over ``fetch_transactions`` for production — Plaid
        returns only the deltas, paginated by ``next_cursor``.

        Returns:
            A dict with the accumulated ``added``, ``modified`` and
            ``removed`` lists plus the final ``next_cursor``.
        """
        # NOTE(review): Plaid can signal a mutation during pagination, in
        # which case the whole sync should restart from ``cursor``; that
        # error is not handled here — confirm whether callers retry.
        first_request = (
            TransactionsSyncRequest(access_token=access_token, cursor=cursor)
            if cursor
            else TransactionsSyncRequest(access_token=access_token)
        )
        response = self._client.transactions_sync(first_request)
        added = list(response["added"])
        modified = list(response["modified"])
        removed = list(response["removed"])
        while response["has_more"]:
            response = self._client.transactions_sync(
                TransactionsSyncRequest(
                    access_token=access_token,
                    cursor=response["next_cursor"],
                )
            )
            added.extend(response["added"])
            modified.extend(response["modified"])
            removed.extend(response["removed"])
        return {
            "added": added,
            "modified": modified,
            "removed": removed,
            "next_cursor": response["next_cursor"],
        }

    # ------------------------------------------------------------------
    # The method the gym actually calls
    # ------------------------------------------------------------------
    def verify_purchase(
        self,
        access_token: str,
        claimed_amount: float,
        claimed_date: str,
        claimed_description: str = "",
        tolerance: float = DEFAULT_TOLERANCE,
        date_range_days: int = DEFAULT_DATE_WINDOW_DAYS,
    ) -> LedgerHit:
        """Find the bank transaction that best matches a claimed purchase.

        Searches ``claimed_date`` ± ``date_range_days`` and scores every
        outgoing transaction on amount, date proximity and description.

        Args:
            access_token: Plaid access token for the claimant.
            claimed_amount: Amount the claimant says they paid.
            claimed_date: Purchase date formatted ``YYYY-MM-DD``.
            claimed_description: Free text; words longer than two characters
                are matched against the merchant name.
            tolerance: Maximum relative amount difference before the match
                is flagged as a discrepancy.
            date_range_days: Half-width of the search window in days.

        Returns:
            A ``LedgerHit``. An unparseable date, a Plaid API error or a
            best score below ``MIN_CONFIDENCE`` yields a "not found" hit
            whose ``discrepancy_reason`` explains why.
        """
        try:
            window_centre = datetime.strptime(claimed_date, "%Y-%m-%d").date()
        except ValueError as exc:
            return _miss(f"Could not parse claimed_date: {exc}")
        start = window_centre - timedelta(days=date_range_days)
        end = window_centre + timedelta(days=date_range_days)
        try:
            transactions = self.fetch_transactions(access_token, start, end)
        except plaid.ApiException as exc:  # type: ignore[attr-defined]
            return _miss(f"Plaid API error: {exc.body}")
        best_tx, best_score = self._best_match(
            transactions=transactions,
            claimed_amount=claimed_amount,
            claimed_description=claimed_description,
            window_centre=window_centre,
            window_days=date_range_days,
        )
        if best_tx is None or best_score < self.MIN_CONFIDENCE:
            return _miss("No matching transaction found in bank records")
        matched_amount = abs(float(best_tx["amount"]))
        diff_pct = abs(matched_amount - claimed_amount) / max(1.0, claimed_amount)
        flagged = diff_pct > tolerance
        return LedgerHit(
            found=True,
            transaction_id=str(best_tx["transaction_id"]),
            amount=matched_amount,
            date=str(best_tx["date"]),
            merchant=str(
                best_tx.get("merchant_name") or best_tx.get("name") or "Unknown"
            ),
            category=(
                best_tx["category"][0] if best_tx.get("category") else "unknown"
            ),
            confidence=best_score,
            discrepancy=flagged,
            discrepancy_reason=(
                f"Claimed ${claimed_amount:,.2f} but transaction shows "
                f"${matched_amount:,.2f}"
                if flagged
                else None
            ),
        )

    # ------------------------------------------------------------------
    # Internal scoring helpers
    # ------------------------------------------------------------------
    def _best_match(
        self,
        *,
        transactions: list[dict[str, Any]],
        claimed_amount: float,
        claimed_description: str,
        window_centre: date,
        window_days: int,
    ) -> tuple[dict[str, Any] | None, float]:
        """Return the highest-scoring outgoing transaction and its score.

        Returns ``(None, 0.0)`` when no transaction scores above zero.
        """
        best_tx: dict[str, Any] | None = None
        best_score = 0.0
        keywords = [
            kw for kw in claimed_description.lower().split() if len(kw) > 2
        ]
        for tx in transactions:
            # Plaid reports money leaving the account as a positive amount
            # and money coming in (refunds, deposits) as negative. Only
            # outflows can evidence a purchase, so inflows never match.
            if float(tx["amount"]) < 0:
                continue
            score = self._score(
                tx=tx,
                claimed_amount=claimed_amount,
                keywords=keywords,
                window_centre=window_centre,
                window_days=window_days,
            )
            if score > best_score:
                best_score, best_tx = score, tx
        return best_tx, best_score

    def _score(
        self,
        *,
        tx: dict[str, Any],
        claimed_amount: float,
        keywords: list[str],
        window_centre: date,
        window_days: int,
    ) -> float:
        """Weighted similarity in [0, 1] between *tx* and the claim."""
        amount = abs(float(tx["amount"]))
        amount_diff = abs(amount - claimed_amount) / max(1.0, claimed_amount)
        amount_score = max(0.0, 1.0 - amount_diff)

        tx_date = self._parse_tx_date(tx.get("date"))
        if tx_date is None:
            # An unreadable date earns no date credit, rather than being
            # treated as an exact hit on the claimed date.
            date_score = 0.0
        else:
            days_diff = abs((tx_date - window_centre).days)
            date_score = max(0.0, 1.0 - days_diff / max(1, window_days))

        merchant = (tx.get("merchant_name") or tx.get("name") or "").lower()
        if keywords:
            description_score = (
                1.0 if any(kw in merchant for kw in keywords) else 0.5
            )
        else:
            description_score = 0.5

        return (
            self.AMOUNT_WEIGHT * amount_score
            + self.DATE_WEIGHT * date_score
            + self.DESCRIPTION_WEIGHT * description_score
        )

    @staticmethod
    def _parse_tx_date(raw: Any) -> date | None:
        """Return *raw* as a ``date``, or ``None`` if it cannot be read.

        Accepts ``date``/``datetime`` objects as-is and ``YYYY-MM-DD``
        strings.
        """
        # datetime subclasses date, so check it first and drop the time part.
        if isinstance(raw, datetime):
            return raw.date()
        if isinstance(raw, date):
            return raw
        try:
            return datetime.strptime(str(raw), "%Y-%m-%d").date()
        except (ValueError, TypeError):
            return None
# ---------------------------------------------------------------------------
# Module-level helpers
# ---------------------------------------------------------------------------
def _miss(reason: str) -> LedgerHit:
    """Return a blank, unverified ``LedgerHit`` explaining why.

    ``found`` is False and every data field is empty or zero.
    ``discrepancy`` is True, and ``discrepancy_reason`` carries *reason*.
    """
    blank_fields = dict(
        transaction_id="",
        amount=0.0,
        date="",
        merchant="",
        category="",
        confidence=0.0,
    )
    return LedgerHit(
        found=False,
        discrepancy=True,
        discrepancy_reason=reason,
        **blank_fields,
    )
def get_plaid_gateway() -> "PlaidGateway":
    """Return a ``PlaidGateway`` built with no arguments.

    All configuration therefore comes from the ``PLAID_*`` environment
    variables. Any exception raised by ``PlaidGateway()`` propagates
    unchanged, e.g. ``ImportError`` without plaid-python or ``ValueError``
    without credentials.
    """
    gateway = PlaidGateway()
    return gateway
def summarize_ledger_hit(hit: LedgerHit) -> str:
    """Render *hit* as a one-line, human-readable status string.

    The format is shared with ``plaid_mock`` so log output looks the same
    whichever backend produced the hit.
    """
    if hit.found:
        label = "DISCREPANCY DETECTED" if hit.discrepancy else "VERIFIED"
        pieces = [
            f"{label}: Transaction found - ${hit.amount:,.2f} at "
            f"{hit.merchant} on {hit.date}"
        ]
        if hit.discrepancy:
            pieces.append(f" | WARNING: {hit.discrepancy_reason}")
        return "".join(pieces)
    return f"VERIFICATION FAILED: {hit.discrepancy_reason}"
# Backwards-compat aliases, kept so code written against the older names
# still imports.
PlaidClient = PlaidGateway
get_plaid_client = get_plaid_gateway
format_verification_result = summarize_ledger_hit

__all__ = [
    "LedgerHit",
    "PlaidGateway",
    "summarize_ledger_hit",
    "get_plaid_gateway",
    # Feature flag: True when the plaid-python import succeeded.
    "PLAID_AVAILABLE",
    # Legacy aliases.
    "TransactionMatch",
    "PlaidClient",
    "get_plaid_client",
    "format_verification_result",
]