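# NOTE(review): the six lines below look like hosting-page residue captured
# along with the file; they are kept as comments so the module parses as Python.
# rtferraz's picture
# Add predefined schemas (FINANCE, ECOMMERCE, HEALTHCARE)
# c00ac2c verified
# raw
# history blame
# 4.61 kB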
"""
Predefined domain schemas for common use cases.
Each schema follows the validated patterns from the research:
- FINANCE_SCHEMA: Based on Nubank nuFormer (arXiv:2507.23267) — 97 special tokens
- ECOMMERCE_SCHEMA: Adapted from ActionPiece (arXiv:2502.13581) + nuFormer patterns
- HEALTHCARE_SCHEMA: Clinical event sequences
"""
from ..schema import DomainSchema, FieldSpec, FieldType
# =============================================================================
# FINANCE SCHEMA — Based on Nubank nuFormer
# sign(2) + amount_bucket(21) + month(12) + dow(7) + dom(31) + hour(24) = 97
# =============================================================================
# Field order mirrors the description string below:
# sign, amount bucket, calendar features, text description.
FINANCE_SCHEMA = DomainSchema(
    name="finance",
    description=(
        "Financial transaction schema following Nubank nuFormer (arXiv:2507.23267). "
        "Each transaction = sign + amount bucket + calendar features + text description. "
        "~14 tokens per transaction, 2048 context = ~146 transactions."
    ),
    fields=[
        # Sign field, declared separately from the binned amount.
        FieldSpec(
            name="amount_sign",
            prefix="AMT_SIGN",
            field_type=FieldType.SIGN,
        ),
        # Continuous amount, configured with 21 bins.
        FieldSpec(
            name="amount",
            prefix="AMT",
            n_bins=21,
            field_type=FieldType.NUMERICAL_CONTINUOUS,
        ),
        # Temporal field with four calendar components; no prefix is given.
        FieldSpec(
            name="timestamp",
            calendar_fields=["month", "dow", "dom", "hour"],
            field_type=FieldType.TEMPORAL,
        ),
        # Free-text description of the transaction.
        FieldSpec(
            name="description",
            prefix="DESC",
            field_type=FieldType.TEXT,
        ),
    ],
)
# =============================================================================
# E-COMMERCE SCHEMA — Adapted from ActionPiece + nuFormer patterns
# =============================================================================
# Six fields per event: event type, price, quantity, product category,
# timestamp, and product title (in that order).
ECOMMERCE_SCHEMA = DomainSchema(
    name="ecommerce",
    description=(
        "E-commerce event schema adapted from ActionPiece (arXiv:2502.13581) "
        "and nuFormer patterns. Events: view/cart/purchase/return/wishlist. "
        "~16 tokens per event, 2048 context = ~128 events."
    ),
    fields=[
        # Closed vocabulary of five event kinds.
        FieldSpec(
            name="event_type",
            prefix="EVT",
            categories=["view", "add_to_cart", "purchase", "return", "wishlist"],
            field_type=FieldType.CATEGORICAL_FIXED,
        ),
        # Continuous price, configured with 21 bins.
        FieldSpec(
            name="price",
            prefix="PRICE",
            n_bins=21,
            field_type=FieldType.NUMERICAL_CONTINUOUS,
        ),
        # Discrete quantity with max_value set to 10.
        FieldSpec(
            name="quantity",
            prefix="QTY",
            max_value=10,
            field_type=FieldType.NUMERICAL_DISCRETE,
        ),
        # Closed vocabulary of twenty product categories, ending in "other".
        FieldSpec(
            name="category",
            prefix="CAT",
            categories=[
                "electronics", "clothing", "home_garden", "books",
                "sports", "toys", "food_grocery", "health_beauty",
                "automotive", "office", "pet_supplies", "jewelry",
                "music", "movies", "games", "baby",
                "tools", "arts_crafts", "industrial", "other",
            ],
            field_type=FieldType.CATEGORICAL_FIXED,
        ),
        # Temporal field with four calendar components; no prefix is given.
        FieldSpec(
            name="timestamp",
            calendar_fields=["month", "dow", "dom", "hour"],
            field_type=FieldType.TEMPORAL,
        ),
        # Free-text product title.
        FieldSpec(
            name="product_title",
            prefix="TITLE",
            field_type=FieldType.TEXT,
        ),
    ],
)
# =============================================================================
# HEALTHCARE SCHEMA — Clinical event sequences
# =============================================================================
# Six fields per clinical event: event type, cost, severity, provider type,
# timestamp, and a text description (in that order).
HEALTHCARE_SCHEMA = DomainSchema(
    name="healthcare",
    description=(
        "Clinical event schema for healthcare sequences. "
        "Events: diagnosis/procedure/lab/medication/visit."
    ),
    fields=[
        # Closed vocabulary of ten clinical event kinds.
        FieldSpec(
            name="event_type",
            prefix="CLIN",
            categories=[
                "diagnosis",
                "procedure",
                "lab_result",
                "medication",
                "visit_inpatient",
                "visit_outpatient",
                "visit_er",
                "imaging",
                "referral",
                "discharge",
            ],
            field_type=FieldType.CATEGORICAL_FIXED,
        ),
        # Continuous cost, configured with 21 bins.
        FieldSpec(
            name="cost",
            prefix="COST",
            n_bins=21,
            field_type=FieldType.NUMERICAL_CONTINUOUS,
        ),
        # Four severity levels, listed from "low" to "critical".
        FieldSpec(
            name="severity",
            prefix="SEV",
            categories=["low", "moderate", "high", "critical"],
            field_type=FieldType.CATEGORICAL_FIXED,
        ),
        # Closed vocabulary of eight provider kinds, ending in "other".
        FieldSpec(
            name="provider_type",
            prefix="PROV",
            categories=[
                "pcp",
                "specialist",
                "surgeon",
                "er_physician",
                "nurse_practitioner",
                "therapist",
                "pharmacist",
                "other",
            ],
            field_type=FieldType.CATEGORICAL_FIXED,
        ),
        # Temporal field with three calendar components and no "hour" entry;
        # no prefix is given.
        FieldSpec(
            name="timestamp",
            calendar_fields=["month", "dow", "dom"],
            field_type=FieldType.TEMPORAL,
        ),
        # Free-text description of the event.
        FieldSpec(
            name="description",
            prefix="DESC",
            field_type=FieldType.TEXT,
        ),
    ],
)