# IntegrationTest / models / gaze_calibration.py
# Abdelrahman Almatrooshi
# Integrate L2CS-Net gaze estimation
# 2eba0cc
# 9-point gaze calibration for L2CS-Net
# Maps raw gaze angles -> normalised screen coords via polynomial least-squares.
# Centre point is the bias reference (subtracted from all readings).
import numpy as np
from dataclasses import dataclass, field
# Default calibration targets as (x, y) pairs on a 3x3 grid.
# The centre comes first because the fit uses the first point as its bias
# reference; the remaining eight points follow row by row (the centre is
# not repeated inside the middle row).
DEFAULT_TARGETS = [
    (0.5, 0.5),    # centre (bias reference)
    (0.15, 0.15),  # row y = 0.15
    (0.50, 0.15),
    (0.85, 0.15),
    (0.15, 0.50),  # row y = 0.50, centre omitted
    (0.85, 0.50),
    (0.15, 0.85),  # row y = 0.85
    (0.50, 0.85),
    (0.85, 0.85),
]
@dataclass
class _PointSamples:
    """Raw gaze samples gathered while one calibration target is shown.

    ``target_x`` / ``target_y`` are the target's coordinates; ``yaws`` and
    ``pitches`` accumulate the angle readings recorded for that target.
    """

    # Coordinates of the calibration target these samples belong to.
    target_x: float
    target_y: float
    # Each instance gets its own fresh lists (never shared between points).
    yaws: list = field(default_factory=list)
    pitches: list = field(default_factory=list)
def _iqr_filter(values):
    """Return *values* with non-finite entries and IQR outliers removed.

    Non-finite entries (NaN, +/-inf) are dropped first: a single NaN would
    otherwise turn both quartiles into NaN and cause every value to be
    rejected. If fewer than four finite values remain there is too little
    data to estimate quartiles, so they are returned without outlier
    filtering. Otherwise only values inside
    ``[Q1 - 1.5 * IQR, Q3 + 1.5 * IQR]`` are kept.

    Args:
        values: Sequence of numeric samples.

    Returns:
        A new list of floats in their original relative order.
    """
    arr = np.asarray(values, dtype=np.float64)
    arr = arr[np.isfinite(arr)]
    if arr.size < 4:
        return arr.tolist()
    q1, q3 = np.percentile(arr, [25, 75])
    iqr = q3 - q1
    lo, hi = q1 - 1.5 * iqr, q3 + 1.5 * iqr
    return arr[(arr >= lo) & (arr <= hi)].tolist()
class GazeCalibration:
    """Collects gaze samples per target and fits a gaze -> screen mapping.

    Samples are gathered one target at a time (``collect_sample`` /
    ``advance``). ``fit`` then solves a least-squares problem that maps the
    second-order polynomial features of the bias-corrected (yaw, pitch)
    medians onto the target coordinates. The first target (index 0) is used
    as the bias reference, so custom target lists should start with the
    point that represents "looking straight ahead".

    Angle arguments are named ``*_rad``; they are presumably radians, but
    nothing in this class depends on the unit as long as it is consistent.
    """

    # Number of terms produced by ``_poly_features`` (rows of the weights).
    _NUM_FEATURES = 6
    # Fewest usable calibration points ``fit`` accepts.
    _MIN_POINTS = 5
    # Returned by ``predict`` whenever no usable estimate is available.
    _FALLBACK = (0.5, 0.5)

    def __init__(self, targets=None):
        """Create an empty calibration.

        Args:
            targets: Optional sequence of (x, y) targets. When omitted or
                empty, ``DEFAULT_TARGETS`` is used.
        """
        # Always copy: the instance must never alias the module-level
        # default list or a list the caller may keep mutating.
        self._targets = list(targets) if targets else list(DEFAULT_TARGETS)
        self._points = [_PointSamples(tx, ty) for tx, ty in self._targets]
        self._current_idx = 0
        self._fitted = False
        self._W = None  # (6, 2) polynomial weights
        self._yaw_bias = 0.0
        self._pitch_bias = 0.0

    @property
    def num_points(self):
        """Total number of calibration targets."""
        return len(self._targets)

    @property
    def current_index(self):
        """Index of the target currently being sampled."""
        return self._current_idx

    @property
    def current_target(self):
        """The active (x, y) target; the last target once collection is done."""
        if self._current_idx < len(self._targets):
            return self._targets[self._current_idx]
        return self._targets[-1]

    @property
    def is_complete(self):
        """True once ``advance`` has moved past the final target."""
        return self._current_idx >= len(self._targets)

    @property
    def is_fitted(self):
        """True when a mapping is available for ``predict``."""
        return self._fitted

    def collect_sample(self, yaw_rad, pitch_rad):
        """Record one gaze reading for the current target.

        Readings are ignored once all targets have been visited, and
        non-finite readings (NaN / inf) are skipped so that a bad frame
        cannot poison the medians computed in ``fit``.
        """
        if self._current_idx >= len(self._points):
            return
        yaw, pitch = float(yaw_rad), float(pitch_rad)
        if not (np.isfinite(yaw) and np.isfinite(pitch)):
            return
        pt = self._points[self._current_idx]
        pt.yaws.append(yaw)
        pt.pitches.append(pitch)

    def advance(self):
        """Move to the next target.

        Returns:
            True while there is still a target left to sample, else False.
        """
        self._current_idx += 1
        return self._current_idx < len(self._targets)

    @staticmethod
    def _poly_features(yaw, pitch):
        """Return the feature vector [yaw^2, pitch^2, yaw*pitch, yaw, pitch, 1]."""
        return np.array([yaw**2, pitch**2, yaw * pitch, yaw, pitch, 1.0],
                        dtype=np.float64)

    def fit(self):
        """Fit the polynomial mapping from the collected samples.

        The bias is the median of the outlier-filtered samples of the first
        target. Every target with at least two usable yaw and pitch samples
        contributes one equation built from its bias-corrected medians.

        State is only updated when the fit succeeds, so a failed re-fit
        leaves a previously fitted calibration fully intact.

        Returns:
            True if a mapping was fitted, False if there was too little
            usable data or the solver failed.
        """
        # bias from centre point (index 0)
        center = self._points[0]
        center_yaws = _iqr_filter(center.yaws)
        center_pitches = _iqr_filter(center.pitches)
        if len(center_yaws) < 2 or len(center_pitches) < 2:
            return False
        # Keep the new bias local until the whole fit is known to be good.
        yaw_bias = float(np.median(center_yaws))
        pitch_bias = float(np.median(center_pitches))

        rows_a, rows_b = [], []
        for pt in self._points:
            # Yaw and pitch are filtered independently; only their medians
            # are used, so losing the pairing between them is harmless.
            clean_yaws = _iqr_filter(pt.yaws)
            clean_pitches = _iqr_filter(pt.pitches)
            if len(clean_yaws) < 2 or len(clean_pitches) < 2:
                continue
            med_yaw = float(np.median(clean_yaws)) - yaw_bias
            med_pitch = float(np.median(clean_pitches)) - pitch_bias
            rows_a.append(self._poly_features(med_yaw, med_pitch))
            rows_b.append([pt.target_x, pt.target_y])

        # NOTE: with exactly five points the six-term system is
        # underdetermined; lstsq then returns the minimum-norm solution.
        if len(rows_a) < self._MIN_POINTS:
            return False

        A = np.array(rows_a, dtype=np.float64)
        B = np.array(rows_b, dtype=np.float64)
        # Never hand NaN/inf to the solver; the result would be unusable.
        if not (np.isfinite(A).all() and np.isfinite(B).all()):
            return False
        try:
            W, _, _, _ = np.linalg.lstsq(A, B, rcond=None)
        except np.linalg.LinAlgError:
            return False
        if not np.isfinite(W).all():
            return False

        # Commit bias and weights together so they always match.
        self._yaw_bias = yaw_bias
        self._pitch_bias = pitch_bias
        self._W = W
        self._fitted = True
        return True

    def predict(self, yaw_rad, pitch_rad):
        """Map a gaze reading to (x, y), each clipped to [0, 1].

        Returns (0.5, 0.5) when no mapping has been fitted or when the
        reading (or the resulting estimate) is not finite.
        """
        if not self._fitted or self._W is None:
            return self._FALLBACK
        yaw = yaw_rad - self._yaw_bias
        pitch = pitch_rad - self._pitch_bias
        if not (np.isfinite(yaw) and np.isfinite(pitch)):
            return self._FALLBACK
        xy = self._poly_features(yaw, pitch) @ self._W
        if not np.isfinite(xy).all():
            return self._FALLBACK
        return float(np.clip(xy[0], 0, 1)), float(np.clip(xy[1], 0, 1))

    def to_dict(self):
        """Return the calibration state as a plain dict (see ``from_dict``).

        Collected samples are not included, only the targets, progress
        index, bias and fitted weights.
        """
        return {
            # Copy so callers cannot mutate the instance's target list.
            "targets": list(self._targets),
            "fitted": self._fitted,
            "current_index": self._current_idx,
            "W": self._W.tolist() if self._W is not None else None,
            "yaw_bias": self._yaw_bias,
            "pitch_bias": self._pitch_bias,
        }

    @classmethod
    def from_dict(cls, d):
        """Rebuild a calibration from the dict produced by ``to_dict``.

        Missing keys fall back to the defaults of a fresh instance. The
        weights are only accepted when they form a finite (6, 2) matrix,
        and the result is only marked as fitted when such weights exist,
        so ``predict`` can never be handed an unusable matrix.
        """
        cal = cls(targets=d.get("targets"))
        # A negative index would silently address targets from the end.
        cal._current_idx = max(0, int(d.get("current_index") or 0))
        cal._yaw_bias = float(d.get("yaw_bias") or 0.0)
        cal._pitch_bias = float(d.get("pitch_bias") or 0.0)
        w = d.get("W")
        if w is not None:
            try:
                weights = np.array(w, dtype=np.float64)
            except (TypeError, ValueError):
                weights = None
            if (
                weights is not None
                and weights.shape == (cls._NUM_FEATURES, 2)
                and np.isfinite(weights).all()
            ):
                cal._W = weights
        cal._fitted = bool(d.get("fitted", False)) and cal._W is not None
        return cal