# 9-point gaze calibration for L2CS-Net
# Maps raw gaze angles -> normalised screen coords via polynomial least-squares.
# Centre point is the bias reference (subtracted from all readings).

import numpy as np
from dataclasses import dataclass, field

# 3x3 grid, centre first (bias ref), then row by row
DEFAULT_TARGETS = [
    (0.5, 0.5),
    (0.15, 0.15), (0.50, 0.15), (0.85, 0.15),
    (0.15, 0.50),               (0.85, 0.50),
    (0.15, 0.85), (0.50, 0.85), (0.85, 0.85),
]

# Number of coefficients in the feature vector built by
# GazeCalibration._poly_features. A least-squares fit needs at least this many
# calibration points, otherwise the system is underdetermined and lstsq only
# returns an arbitrary minimum-norm solution.
_NUM_POLY_FEATURES = 6

# Minimum number of samples that must survive outlier rejection before a
# point's median is trusted.
_MIN_SAMPLES_PER_POINT = 2


@dataclass
class _PointSamples:
    """Raw gaze samples collected while the user fixated one target."""

    target_x: float
    target_y: float
    yaws: list = field(default_factory=list)
    pitches: list = field(default_factory=list)


def _iqr_filter(values):
    """Return the finite entries of *values* with 1.5*IQR outliers removed.

    Non-finite samples (NaN/inf) are dropped first: a single NaN would
    otherwise turn both quartiles into NaN, every comparison would be False
    and the whole point would be rejected. With fewer than 4 finite samples
    the quartiles are not meaningful, so those samples are returned as-is.

    Returns a new list of floats.
    """
    arr = np.asarray(values, dtype=np.float64)
    arr = arr[np.isfinite(arr)]
    if arr.size < 4:
        return arr.tolist()
    q1, q3 = np.percentile(arr, [25, 75])
    iqr = q3 - q1
    lo, hi = q1 - 1.5 * iqr, q3 + 1.5 * iqr
    return arr[(arr >= lo) & (arr <= hi)].tolist()


def _robust_median(values):
    """Median of *values* after `_iqr_filter`, or None if too few survive."""
    clean = _iqr_filter(values)
    if len(clean) < _MIN_SAMPLES_PER_POINT:
        return None
    return float(np.median(clean))


class GazeCalibration:
    """Collect per-target gaze samples and fit a quadratic gaze->screen map.

    Workflow: while the user looks at `current_target`, call `collect_sample`
    repeatedly, then `advance` to the next target. When all targets are done,
    `fit` estimates the mapping and `predict` converts new gaze readings into
    normalised screen coordinates clipped to [0, 1].

    The first target is the bias reference: the robust median of its samples
    is subtracted from every reading before fitting and predicting. With
    `DEFAULT_TARGETS` this is the screen centre.
    """

    def __init__(self, targets=None):
        """Create an empty calibration.

        Args:
            targets: sequence of (x, y) normalised screen targets; the first
                one is used as the bias reference. Falls back to
                `DEFAULT_TARGETS` when None or empty.
        """
        # Copy and normalise to tuples so this instance shares no mutable
        # state with the caller or with the module-level DEFAULT_TARGETS.
        source = targets if targets else DEFAULT_TARGETS
        self._targets = [tuple(t) for t in source]
        self._points = [_PointSamples(tx, ty) for tx, ty in self._targets]
        self._current_idx = 0
        self._fitted = False
        self._W = None  # (6, 2) polynomial weights
        self._yaw_bias = 0.0
        self._pitch_bias = 0.0

    @property
    def num_points(self):
        """Number of calibration targets."""
        return len(self._targets)

    @property
    def current_index(self):
        """Index of the target currently being collected."""
        return self._current_idx

    @property
    def current_target(self):
        """Target currently being collected (the last one once complete)."""
        if self._current_idx < len(self._targets):
            return self._targets[self._current_idx]
        return self._targets[-1]

    @property
    def is_complete(self):
        """True once `advance` has moved past every target."""
        return self._current_idx >= len(self._targets)

    @property
    def is_fitted(self):
        """True after a successful `fit` (or loading fitted weights)."""
        return self._fitted

    def collect_sample(self, yaw_rad, pitch_rad):
        """Record one gaze reading for the current target.

        The parameter names indicate radians; the fit itself is unit-agnostic
        as long as collection and prediction use the same units. Ignored once
        every target has been visited.
        """
        if self._current_idx >= len(self._points):
            return
        pt = self._points[self._current_idx]
        pt.yaws.append(float(yaw_rad))
        pt.pitches.append(float(pitch_rad))

    def advance(self):
        """Move to the next target; return True while targets remain."""
        self._current_idx += 1
        return self._current_idx < len(self._targets)

    @staticmethod
    def _poly_features(yaw, pitch):
        # [yaw^2, pitch^2, yaw*pitch, yaw, pitch, 1]
        return np.array([yaw**2, pitch**2, yaw * pitch, yaw, pitch, 1.0],
                        dtype=np.float64)

    def fit(self):
        """Fit the quadratic mapping from the collected samples.

        Each point contributes the median of its IQR-filtered samples, shifted
        by the bias taken from the first point. Points with too few usable
        samples are skipped.

        The update is atomic: on failure (unusable reference point, fewer than
        `_NUM_POLY_FEATURES` usable points, or an lstsq failure) nothing is
        modified. In particular, a previous successful fit keeps its matching
        weights and biases.

        Returns:
            True if a new mapping was fitted, False otherwise.
        """
        # bias from reference point (index 0)
        center = self._points[0]
        yaw_bias = _robust_median(center.yaws)
        pitch_bias = _robust_median(center.pitches)
        if yaw_bias is None or pitch_bias is None:
            return False

        rows_A, rows_B = [], []
        for pt in self._points:
            med_yaw = _robust_median(pt.yaws)
            med_pitch = _robust_median(pt.pitches)
            if med_yaw is None or med_pitch is None:
                continue
            rows_A.append(self._poly_features(med_yaw - yaw_bias,
                                              med_pitch - pitch_bias))
            rows_B.append([pt.target_x, pt.target_y])

        # Need at least as many equations as unknowns per output column.
        if len(rows_A) < _NUM_POLY_FEATURES:
            return False

        A = np.array(rows_A, dtype=np.float64)
        B = np.array(rows_B, dtype=np.float64)
        try:
            W, *_ = np.linalg.lstsq(A, B, rcond=None)
        except np.linalg.LinAlgError:
            return False

        # Commit weights and biases together so they always stay consistent.
        self._W = W
        self._yaw_bias = yaw_bias
        self._pitch_bias = pitch_bias
        self._fitted = True
        return True

    def predict(self, yaw_rad, pitch_rad):
        """Map a gaze reading to normalised (x, y) screen coords in [0, 1].

        Returns (0.5, 0.5) when no mapping has been fitted.
        """
        if not self._fitted or self._W is None:
            return 0.5, 0.5
        feat = self._poly_features(yaw_rad - self._yaw_bias,
                                   pitch_rad - self._pitch_bias)
        xy = feat @ self._W
        return float(np.clip(xy[0], 0, 1)), float(np.clip(xy[1], 0, 1))

    def to_dict(self):
        """Serialise targets, progress, weights and biases to plain Python types.

        Raw samples are not included.
        """
        return {
            "targets": list(self._targets),
            "fitted": self._fitted,
            "current_index": self._current_idx,
            "W": self._W.tolist() if self._W is not None else None,
            "yaw_bias": self._yaw_bias,
            "pitch_bias": self._pitch_bias,
        }

    @classmethod
    def from_dict(cls, d):
        """Rebuild a calibration from the output of `to_dict`.

        Samples are not serialised, so the result can `predict` but cannot be
        re-fitted until new samples are collected. Targets are restored as
        tuples even if the dict went through JSON (which turns them into lists).

        Raises:
            ValueError: if a stored weight matrix does not have shape (6, 2).
        """
        cal = cls(targets=d.get("targets"))
        cal._current_idx = d.get("current_index", 0)
        cal._yaw_bias = d.get("yaw_bias", 0.0)
        cal._pitch_bias = d.get("pitch_bias", 0.0)
        w = d.get("W")
        if w is not None:
            W = np.array(w, dtype=np.float64)
            if W.shape != (_NUM_POLY_FEATURES, 2):
                raise ValueError(
                    f"expected W of shape ({_NUM_POLY_FEATURES}, 2), "
                    f"got {W.shape}"
                )
            cal._W = W
        # Without weights there is nothing to predict with, whatever the flag.
        cal._fitted = bool(d.get("fitted", False)) and cal._W is not None
        return cal