File size: 6,916 Bytes
853e25d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 | """Feature builder: human-readable inputs -> 133-dim feature vector.
Design principle: we don't fake the speed-dating-specific features (like
partner attractiveness rating, ambition rating from the opposite gender).
These default to population means (0.5 on a 0-1 scale for normalized
features, 5 on a 1-10 scale for raw). The Gottman and survival features —
which dominate SHAP — are computed precisely from the user's answers.
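If the loaded feature_columns list contains names we don't know how to
populate, we fill them with a heuristic default chosen from the column
name (5.0 for 1-10 rating and interest scales, 16.67 for 100-point
importance weights, 0.0 for binary flags, 0.5 otherwise) and log at
DEBUG level how many were defaulted, along with the first few names.
This keeps predictions directionally correct without pretending to have
data we don't have.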
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
import numpy as np
from src import gottman_scorer, survival_scorer
log = logging.getLogger(__name__)
@dataclass
class UserInputs:
    """Answers to the 15 questions on the Quick Check form.

    The 15 fields without defaults correspond to the form questions
    (4 demographics, 4 partner, 5 relationship dynamics, 2 context). The
    trailing fields carry defaults because the form does not ask for them;
    ``build`` passes them to the Gottman / survival scorers, and
    ``years_together`` is also emitted directly as a feature.

    No range or vocabulary validation is performed by this class.
    NOTE(review): gender_partner, education_partner and career_partner are
    not read anywhere in the visible code of this module -- confirm whether
    they are meant to feed a feature.
    """

    # Demographics (4)
    age_you: int
    gender_you: str  # "female" | "male" | "nonbinary" | "prefer_not" (GENDER_MAP keys)
    education_you: str  # EDU_MAP key: "high_school" | "bachelors" | "masters" | "phd" | ...
    career_you: str  # free-text category; looked up in CAREER_MAP after normalization
    # Partner (4) -- presumably same vocabularies as the "_you" fields; unused here
    age_partner: int
    gender_partner: str
    education_partner: str
    career_partner: str
    # Dynamic (5) -- Gottman-style answers on a 1-5 scale
    shared_interests: int  # 1-5
    shared_goals: int  # 1-5
    repair_after_conflict: int  # 1-5
    criticism_frequency: int  # 1-5 (higher = more criticism)
    stonewalling: int  # 1-5
    # Context (2)
    relationship_type: str  # "love" | "arranged" | "other"
    marriage_number: int  # 1, 2, 3+
    # Not asked on the form: fixed defaults unless the caller overrides them
    years_together: float = 2.0
    contempt_frequency: int = 2  # presumably 1-5 like the other Gottman answers
    defensiveness: int = 2  # presumably 1-5; defaulted, not collected
    know_inner_world: int = 4  # presumably 1-5; defaulted, not collected
    partner_knows_me: int = 4  # presumably 1-5; defaulted, not collected
# Education label -> ordinal code emitted as the "field_cd" feature.
# "associates" and "other" share the bachelor's code; "doctorate" aliases "phd".
# NOTE(review): in the original speed-dating data "field_cd" encodes field of
# study rather than education level -- confirm this is how the model was trained.
EDU_MAP = dict(
    high_school=2,
    associates=3,
    bachelors=3,
    masters=4,
    phd=5,
    doctorate=5,
    other=3,
)

# Gender label -> numeric code. Binary 0/1 follows the legacy datasets; any
# non-binary answer (or declining to answer) sits at the 0.5 midpoint.
GENDER_MAP = dict(
    female=0,
    male=1,
    nonbinary=0.5,
    prefer_not=0.5,
)

# Career category -> rough 1-10 income/status prior (divided by 10 downstream).
# NOTE(review): these values read as priors, not as speed-dating field-of-study
# codes as an earlier comment claimed -- confirm the intended encoding.
CAREER_MAP = dict(
    tech=8,
    finance=7,
    medicine=9,
    academia=6,
    law=7,
    arts=4,
    education=5,
    business=6,
    other=5,
)
def _career_score(c: str) -> float:
    """Turn a free-text career label into a 0-1 prior.

    The label is lowercased, trimmed and has spaces replaced by underscores,
    then looked up in CAREER_MAP; labels not found there score 0.5.
    """
    slug = c.lower().strip().replace(" ", "_")
    raw_score = CAREER_MAP[slug] if slug in CAREER_MAP else 5
    return raw_score / 10.0
def _normalize_label(value: str) -> str:
    """Canonicalize a form label for map lookup: lowercase, trim, spaces -> underscores."""
    return value.lower().strip().replace(" ", "_")


def _direct_features(u: UserInputs) -> dict[str, float]:
    """Return the features that can be set directly from the form answers.

    Categorical answers are normalized the same way ``_career_score`` does
    (lowercased, trimmed, spaces -> underscores) before lookup, so inputs
    such as ``"High School"`` or ``" Male "`` resolve to their codes instead
    of silently falling back to the default.

    Columns the form does not collect (race, goal) get fixed placeholders.

    Returns:
        Mapping of feature name -> value for the directly derivable columns.
    """
    age_avg = (u.age_you + u.age_partner) / 2.0
    age_diff = abs(u.age_you - u.age_partner)
    return {
        # Age features
        "age": u.age_you,
        "age_o": u.age_partner,
        "d_age": age_diff,
        "age_avg": age_avg,
        # Demographics; unknown gender labels fall back to the 0.5 midpoint.
        "gender": GENDER_MAP.get(_normalize_label(u.gender_you), 0.5),
        "samerace": 1.0,  # Assumed; not collected in 15-question form
        "race": 0.0,
        "race_o": 0.0,
        # Dating-goal placeholder, documented as "serious relationship".
        # NOTE(review): in the Columbia speed-dating key 3 = "to get a date" and
        # 4 = "serious relationship" -- confirm which code is intended.
        "goal": 3.0,
        # Education -> ordinal code; unknown labels fall back to 3.
        "field_cd": EDU_MAP.get(_normalize_label(u.education_you), 3),
        # Income proxy: 0-1 career prior scaled to dollars.
        "income": _career_score(u.career_you) * 100000,
        # Meta context
        "marriage_number": u.marriage_number,
        "years_together": u.years_together,
    }
def build(u: UserInputs, feature_columns: list[str]) -> np.ndarray:
    """Assemble the model input vector for ``u`` in ``feature_columns`` order.

    Values come from three sources, merged so that a later source wins on a
    name clash: direct form features, then Gottman scores, then survival
    features. Any column none of them provides is filled by ``_default_for``
    and reported in a single DEBUG log line (count plus the first five names).

    Returns:
        A float32 array of shape (len(feature_columns),).
    """
    # Engineered features: Gottman relationship scores, then survival model.
    gottman_answers = gottman_scorer.GottmanAnswers(
        shared_interests=u.shared_interests,
        shared_goals=u.shared_goals,
        know_inner_world=u.know_inner_world,
        partner_knows_me=u.partner_knows_me,
        repair_after_conflict=u.repair_after_conflict,
        criticism_frequency=u.criticism_frequency,
        contempt_frequency=u.contempt_frequency,
        defensiveness=u.defensiveness,
        stonewalling=u.stonewalling,
    )
    gottman = gottman_scorer.score(gottman_answers)

    survival_inputs = survival_scorer.SurvivalInputs(
        age_you=u.age_you,
        age_partner=u.age_partner,
        marriage_number=u.marriage_number,
        relationship_type=u.relationship_type,
        years_together=u.years_together,
    )
    survival = survival_scorer.compute(survival_inputs)

    direct = _direct_features(u)

    # Later sources override earlier ones on duplicate keys.
    known: dict[str, float] = {}
    for source in (direct, gottman, survival):
        known.update(source)

    values: list[float] = []
    defaulted: list[str] = []
    for col in feature_columns:
        if col in known:
            values.append(float(known[col]))
            continue
        # Nothing derivable for this column: use the name-based heuristic.
        values.append(_default_for(col))
        defaulted.append(col)

    if defaulted:
        log.debug(
            "Defaulted %d/%d features to population means: %s...",
            len(defaulted),
            len(feature_columns),
            defaulted[:5],
        )
    return np.array(values, dtype=np.float32)
def _default_for(col: str) -> float:
"""Heuristic default for a feature we can't derive from user input.
Speed-dating features are typically 1-10 scales (rating attributes) or
0-1 normalized scores. We use 5.0 for rating-style, 0.5 for normalized,
and 0.0 for binary.
"""
name = col.lower()
# Perception/rating features (attr, sinc, intel, fun, amb, shar)
if any(name.startswith(p) for p in ("attr", "sinc", "intel", "fun", "amb", "shar")):
return 5.0
# Importance weight features (sum to ~100 in speed dating)
if "imp" in name or "pref" in name:
return 16.67 # ~100/6 across the six attributes
# Interest-category features (1-10)
if any(k in name for k in ("sports", "tvsports", "exercise", "dining",
"museums", "art", "hiking", "gaming",
"clubbing", "reading", "tv", "theater",
"movies", "concerts", "music", "shopping",
"yoga")):
return 5.0
# Binary/flag features
if name.startswith(("is_", "has_", "same", "dec_")):
return 0.0
# Otherwise: neutral midpoint
return 0.5
|