Spaces:
Sleeping
Sleeping
| # 9-point gaze calibration for L2CS-Net | |
| # Maps raw gaze angles -> normalised screen coords via polynomial least-squares. | |
| # Centre point is the bias reference (subtracted from all readings). | |
| import numpy as np | |
| from dataclasses import dataclass, field | |
# 3x3 grid of normalised screen positions. The centre comes first because
# GazeCalibration.fit() treats index 0 as the bias reference; the other eight
# follow in row-major order (smallest y first, left to right within a row).
DEFAULT_TARGETS = [(0.5, 0.5)] + [
    (x, y)
    for y in (0.15, 0.50, 0.85)
    for x in (0.15, 0.50, 0.85)
    if (x, y) != (0.5, 0.5)  # centre already placed at index 0
]
| class _PointSamples: | |
| target_x: float | |
| target_y: float | |
| yaws: list = field(default_factory=list) | |
| pitches: list = field(default_factory=list) | |
| def _iqr_filter(values): | |
| if len(values) < 4: | |
| return values | |
| arr = np.array(values) | |
| q1, q3 = np.percentile(arr, [25, 75]) | |
| iqr = q3 - q1 | |
| lo, hi = q1 - 1.5 * iqr, q3 + 1.5 * iqr | |
| return arr[(arr >= lo) & (arr <= hi)].tolist() | |
class GazeCalibration:
    """Target-by-target gaze calibration (nine targets by default).

    Collects raw (yaw, pitch) readings for each on-screen target, then fits a
    second-order polynomial that maps bias-corrected angles to normalised
    screen coordinates. The first target is the bias reference: the median of
    its readings is subtracted from every reading before fitting/predicting.
    """

    def __init__(self, targets=None):
        # Copy and normalise to float tuples so neither the caller's list nor
        # DEFAULT_TARGETS is aliased (and later mutated via to_dict()), and so
        # JSON-loaded [x, y] lists behave exactly like the default tuples.
        source = targets if targets else DEFAULT_TARGETS
        self._targets = [(float(tx), float(ty)) for tx, ty in source]
        self._points = [_PointSamples(tx, ty) for tx, ty in self._targets]
        self._current_idx = 0
        self._fitted = False
        self._W = None  # (6, 2) polynomial weights, one column per screen axis
        self._yaw_bias = 0.0
        self._pitch_bias = 0.0

    def num_points(self):
        """Return the number of calibration targets."""
        return len(self._targets)

    def current_index(self):
        """Return the index of the target currently being collected."""
        return self._current_idx

    def current_target(self):
        """Return the active target, or the last one once collection is done."""
        if self._current_idx < len(self._targets):
            return self._targets[self._current_idx]
        return self._targets[-1]

    def is_complete(self):
        """Return True once every target has been advanced past."""
        return self._current_idx >= len(self._targets)

    def is_fitted(self):
        """Return True if a usable polynomial mapping is available."""
        return self._fitted

    def collect_sample(self, yaw_rad, pitch_rad):
        """Record one raw reading for the current target (ignored when done)."""
        if self._current_idx >= len(self._points):
            return
        pt = self._points[self._current_idx]
        pt.yaws.append(float(yaw_rad))
        pt.pitches.append(float(pitch_rad))

    def advance(self):
        """Move to the next target; return True while targets remain."""
        self._current_idx += 1
        return self._current_idx < len(self._targets)

    @staticmethod
    def _poly_features(yaw, pitch):
        """Return the design-matrix row for one (yaw, pitch) pair.

        Order: [yaw^2, pitch^2, yaw*pitch, yaw, pitch, 1].
        """
        return np.array([yaw**2, pitch**2, yaw * pitch, yaw, pitch, 1.0],
                        dtype=np.float64)

    def fit(self):
        """Fit the angle -> screen polynomial from the collected samples.

        Each point's readings are IQR-filtered and reduced to their median;
        points with fewer than two surviving samples are skipped.

        Returns:
            True on success. On failure returns False and leaves any previous
            calibration (weights *and* biases) untouched.
        """
        # Bias from the centre point (index 0).
        center = self._points[0]
        center_yaws = _iqr_filter(center.yaws)
        center_pitches = _iqr_filter(center.pitches)
        if len(center_yaws) < 2 or len(center_pitches) < 2:
            return False
        # Keep the new bias local until the fit succeeds, so a failed refit
        # cannot pair a new bias with the previously fitted weights.
        yaw_bias = float(np.median(center_yaws))
        pitch_bias = float(np.median(center_pitches))

        rows_A, rows_B = [], []
        for pt in self._points:
            clean_yaws = _iqr_filter(pt.yaws)
            clean_pitches = _iqr_filter(pt.pitches)
            if len(clean_yaws) < 2 or len(clean_pitches) < 2:
                continue
            med_yaw = float(np.median(clean_yaws)) - yaw_bias
            med_pitch = float(np.median(clean_pitches)) - pitch_bias
            rows_A.append(self._poly_features(med_yaw, med_pitch))
            rows_B.append([pt.target_x, pt.target_y])

        # The polynomial has 6 unknowns per axis; fewer usable points leave
        # the least-squares problem under-determined.
        if len(rows_A) < 6:
            return False
        A = np.array(rows_A, dtype=np.float64)
        B = np.array(rows_B, dtype=np.float64)
        try:
            W, _, rank, _ = np.linalg.lstsq(A, B, rcond=None)
        except np.linalg.LinAlgError:
            return False
        # Rank-deficient A (e.g. the gaze barely moved between targets) yields
        # an arbitrary minimum-norm solution rather than a real calibration.
        if rank < A.shape[1] or not np.all(np.isfinite(W)):
            return False

        self._W = W
        self._yaw_bias = yaw_bias
        self._pitch_bias = pitch_bias
        self._fitted = True
        return True

    def predict(self, yaw_rad, pitch_rad):
        """Map raw angles (radians) to normalised screen (x, y).

        Both coordinates are clipped to [0, 1]. Returns the screen centre
        (0.5, 0.5) when no calibration has been fitted or loaded.
        """
        if not self._fitted or self._W is None:
            return 0.5, 0.5
        feat = self._poly_features(yaw_rad - self._yaw_bias,
                                   pitch_rad - self._pitch_bias)
        xy = feat @ self._W
        return float(np.clip(xy[0], 0, 1)), float(np.clip(xy[1], 0, 1))

    def to_dict(self):
        """Serialise the calibration state (collected samples are not saved)."""
        return {
            # Copy so callers cannot mutate our target list through the dict.
            "targets": list(self._targets),
            "fitted": self._fitted,
            "current_index": self._current_idx,
            "W": self._W.tolist() if self._W is not None else None,
            "yaw_bias": self._yaw_bias,
            "pitch_bias": self._pitch_bias,
        }

    @classmethod
    def from_dict(cls, d):
        """Rebuild a calibration from the output of :meth:`to_dict`."""
        cal = cls(targets=d.get("targets", DEFAULT_TARGETS))
        w = d.get("W")
        if w is not None:
            cal._W = np.array(w, dtype=np.float64)
        # Only report fitted when weights were actually restored, so
        # is_fitted() agrees with what predict() can do.
        cal._fitted = bool(d.get("fitted", False)) and cal._W is not None
        cal._current_idx = int(d.get("current_index", 0))
        cal._yaw_bias = float(d.get("yaw_bias", 0.0))
        cal._pitch_bias = float(d.get("pitch_bias", 0.0))
        return cal