# YT-AI-Automation: backend/src/utils/eta_tracker.py
import json
import os
import threading
from typing import TypedDict, Dict, List, Optional, cast
class ModelSpeed(TypedDict):
chars_per_second: float
samples: int
class ScreenshotSpeed(TypedDict):
seconds_per_screenshot: float
samples: int
class VerificationSpeed(TypedDict):
average_seconds: float
samples: int
class ETAData(TypedDict):
models: Dict[str, ModelSpeed]
screenshots: ScreenshotSpeed
verification: VerificationSpeed
class ProcessEstimateSample(TypedDict, total=False):
input_chars: int
seconds: float
resolution: str
concurrent: bool
class ProcessEstimateModel(TypedDict):
seconds_per_char: float
samples: int
runs: List[ProcessEstimateSample]
class ProcessEstimateData(TypedDict):
min_samples: int
models: Dict[str, ProcessEstimateModel]
# ─── Resolution & concurrency feature engineering ───────────────────────────
#
# Video export time scales roughly with pixel count, while AI/screenshot
# stages don't. The factors below are *priors* — they're only used when a
# bucket has fewer than ``_BUCKET_MIN_SAMPLES`` real observations. As more
# data comes in, the factors are computed empirically per (model, resolution).
RESOLUTION_ALIASES: Dict[str, str] = {
"720": "720p",
"720p": "720p",
"hd": "720p",
"1080": "1080p",
"1080p": "1080p",
"fhd": "1080p",
"1440": "1440p",
"1440p": "1440p",
"qhd": "1440p",
"2k": "1440p",
"4k": "4k",
"2160": "4k",
"2160p": "4k",
"uhd": "4k",
}
# Default multipliers vs 1080p baseline, tuned to typical PowerPoint MP4
# export costs. These are only consulted as priors; the tracker will
# replace them with observed data once a (model, resolution) bucket has
# enough samples.
_RESOLUTION_PRIOR_FACTOR: Dict[str, float] = {
"720p": 0.7,
"1080p": 1.0,
"1440p": 1.6,
"4k": 2.5,
}
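# E.g. a model whose 1080p bucket predicts 120 s would be charged
# 120 * 2.5 = 300 s for a 4k run until the 4k bucket itself reaches
# ``_BUCKET_MIN_SAMPLES`` observations.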
# Concurrency slowdown prior: running two pipelines in parallel stretches
# each run to roughly 1.5x its solo duration because the AI/PowerPoint
# stages contend for the same resources.
_CONCURRENCY_PRIOR_FACTOR = 1.5
_BUCKET_MIN_SAMPLES = 3  # minimum samples before a bucket overrides the prior
DEFAULT_RESOLUTION = "1080p"
def normalize_resolution(label: object) -> str:
"""Canonicalize a user-supplied resolution string to a known bucket.
Falls back to ``"1080p"`` so legacy samples without an explicit
resolution land in the same bucket as the modal default — which is
what the user asked for when they said *label existing data as
1080p*.
"""
if isinstance(label, (list, tuple)) and len(label) >= 2:
# Stored as ``[width, height]`` in the run settings — map back to
# the closest named bucket by total pixel count.
try:
pixels = int(label[0]) * int(label[1])
except (TypeError, ValueError):
return DEFAULT_RESOLUTION
if pixels >= 3840 * 2160 * 0.9:
return "4k"
if pixels >= 2560 * 1440 * 0.9:
return "1440p"
if pixels >= 1920 * 1080 * 0.9:
return "1080p"
return "720p"
text = str(label or "").strip().lower()
if not text:
return DEFAULT_RESOLUTION
return RESOLUTION_ALIASES.get(text, DEFAULT_RESOLUTION)
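# Illustrative behavior of ``normalize_resolution`` (derived from the rules
# above):
#   normalize_resolution("QHD")        -> "1440p"   (alias lookup)
#   normalize_resolution([3840, 2160]) -> "4k"      (pixel-count bucketing)
#   normalize_resolution(None)         -> "1080p"   (legacy default)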
class ETATracker:
"""Tracks and predicts completion times based on historical runs."""
def __init__(
self,
storage_path: str = "config/estimated_times.json",
process_storage_path: str = "config/process_time_estimates.json",
) -> None:
self.storage_path = storage_path
self.process_storage_path = process_storage_path
self._lock = threading.Lock()
self._migrated_on_load = False
self.data: ETAData = self._load_data()
self.process_data: ProcessEstimateData = self._load_process_data()
# Persist the resolution/concurrent backfill so we don't re-do it
# on every boot and so the on-disk file matches the in-memory shape.
if self._migrated_on_load:
self._save_process_data()
def _load_data(self) -> ETAData:
"""Load historical timing data from disk, or initialize defaults."""
if os.path.exists(self.storage_path):
try:
with open(self.storage_path, 'r', encoding='utf-8') as f:
data = json.load(f)
                # Migrate old data if necessary
if "verification" not in data:
data["verification"] = {"average_seconds": 15.0, "samples": 0}
return cast(ETAData, data)
except Exception as e:
print(f"⚠️ Error loading ETA data: {e}")
# Defaults if no file exists
return cast(ETAData, {
"models": {
"default": {"chars_per_second": 500.0, "samples": 0},
"fast": {"chars_per_second": 1500.0, "samples": 0},
"kimi": {"chars_per_second": 300.0, "samples": 0},
"deepseek": {"chars_per_second": 400.0, "samples": 0},
"devstral": {"chars_per_second": 1000.0, "samples": 0}
},
"screenshots": {"seconds_per_screenshot": 1.5, "samples": 0},
"verification": {"average_seconds": 15.0, "samples": 0}
})
def _load_process_data(self) -> ProcessEstimateData:
"""Load successful process timing data used for user-facing ETAs."""
if os.path.exists(self.process_storage_path):
try:
with open(self.process_storage_path, 'r', encoding='utf-8') as f:
data = json.load(f)
data.setdefault("min_samples", 10)
data.setdefault("models", {})
# Backfill the new resolution/concurrent fields onto pre-
# existing samples — per the user's request, runs without
# an explicit resolution count as 1080p (the modal default
# and what the existing data was almost certainly captured
# at).
migrated = False
for model_data in data["models"].values():
for run in model_data.get("runs", []) or []:
if "resolution" not in run:
run["resolution"] = DEFAULT_RESOLUTION
migrated = True
if "concurrent" not in run:
run["concurrent"] = False
migrated = True
self._migrated_on_load = migrated
return cast(ProcessEstimateData, data)
except Exception as e:
print(f"Warning: Error loading process ETA data: {e}")
self._migrated_on_load = False
return cast(ProcessEstimateData, {"min_samples": 10, "models": {}})
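    # A sketch of the on-disk shape of ``process_time_estimates.json``
    # (mirrors ``ProcessEstimateData``; values are illustrative):
    # {
    #     "min_samples": 10,
    #     "models": {
    #         "default": {
    #             "seconds_per_char": 0.024,
    #             "samples": 1,
    #             "runs": [{"input_chars": 5000, "seconds": 120.0,
    #                       "resolution": "1080p", "concurrent": false}]
    #         }
    #     }
    # }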
    def _save_data(self) -> None:
        """Save the current timing data to disk."""
        try:
            directory = os.path.dirname(self.storage_path)
            if directory:  # guard against bare filenames (dirname == "")
                os.makedirs(directory, exist_ok=True)
            with open(self.storage_path, 'w', encoding='utf-8') as f:
                json.dump(self.data, f, indent=4)
        except Exception as e:
            print(f"⚠️ Error saving ETA data: {e}")
    def _save_process_data(self) -> None:
        """Save process ETA samples to disk."""
        try:
            directory = os.path.dirname(self.process_storage_path)
            if directory:  # guard against bare filenames (dirname == "")
                os.makedirs(directory, exist_ok=True)
            with open(self.process_storage_path, 'w', encoding='utf-8') as f:
                json.dump(self.process_data, f, indent=4)
        except Exception as e:
            print(f"Warning: Error saving process ETA data: {e}")
    def record_verification(self, seconds: float) -> None:
"""Record the time taken for a verification pass."""
if seconds <= 0:
return
ALPHA = 0.2 # Slower moving average for verification
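        # Exponential moving average: new = ALPHA * sample + (1 - ALPHA) * old.
        # E.g. an old average of 15 s and a 25 s sample give
        # 0.2 * 25 + 0.8 * 15 = 17.0 s.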
ver_data = self.data["verification"]
if ver_data["samples"] == 0:
ver_data["average_seconds"] = seconds
else:
ver_data["average_seconds"] = (ALPHA * seconds) + ((1 - ALPHA) * ver_data["average_seconds"])
ver_data["samples"] += 1
self._save_data()
    def record_completion(
        self,
        model_choice: str,
        input_chars: int,
        ai_seconds: float,
        screenshot_count: int,
        screenshot_seconds: float,
        use_cache: bool = False,
    ) -> None:
"""
Record a successful run to improve future predictions.
AI timing is only recorded when a real AI call happened — filtered by
``ai_seconds`` rather than the user's ``use_cache`` preference, so cache
misses (where the AI actually ran) still feed the ETA model even when
caching is enabled globally.
"""
del use_cache # Preserved for API compatibility; filter by duration below.
updated = False
ALPHA = 0.3 # Moving average weight (30% new data, 70% historical)
# 1. Update AI Speed — filter cache hits by duration (<0.5s is a hit).
if ai_seconds > 0.5 and input_chars > 0:
if "models" not in self.data:
self.data["models"] = {}
if model_choice not in self.data["models"]:
self.data["models"][model_choice] = cast(ModelSpeed, {"chars_per_second": 500.0, "samples": 0})
model_data = self.data["models"][model_choice]
current_cps = input_chars / ai_seconds
if model_data["samples"] == 0:
model_data["chars_per_second"] = current_cps
else:
model_data["chars_per_second"] = (ALPHA * current_cps) + ((1 - ALPHA) * model_data["chars_per_second"])
model_data["samples"] += 1
updated = True
# 2. Update Screenshot Speed (if screenshots taken)
if screenshot_count > 0 and screenshot_seconds > 0:
screens_data = self.data["screenshots"]
current_sps = screenshot_seconds / screenshot_count
if screens_data["samples"] == 0:
screens_data["seconds_per_screenshot"] = current_sps
else:
screens_data["seconds_per_screenshot"] = (ALPHA * current_sps) + ((1 - ALPHA) * screens_data["seconds_per_screenshot"])
screens_data["samples"] += 1
updated = True
if updated:
self._save_data()
def record_process_completion(
self,
model_choice: str,
input_chars: int,
total_seconds: float,
resolution: Optional[object] = None,
concurrent: bool = False,
) -> None:
"""Record a successful end-to-end process sample.
``resolution`` and ``concurrent`` get folded into per-bucket
statistics so the predictor can charge a 4K run more time than a
1080p one and bake in the concurrent-pipeline slowdown rather
than averaging it away. Older callers that don't pass the new
kwargs land in the ``1080p`` / non-concurrent bucket — the same
bucket pre-existing samples migrate into.
"""
if input_chars <= 0 or total_seconds <= 0:
return
model = str(model_choice or "default")
sample: ProcessEstimateSample = {
"input_chars": int(input_chars),
"seconds": round(float(total_seconds), 3),
"resolution": normalize_resolution(resolution),
"concurrent": bool(concurrent),
}
with self._lock:
models = self.process_data.setdefault("models", {})
if model not in models:
models[model] = cast(ProcessEstimateModel, {
"seconds_per_char": 0.0,
"samples": 0,
"runs": [],
})
model_data = models[model]
runs = model_data.setdefault("runs", [])
runs.append(sample)
            del runs[:-100]  # retain only the 100 most recent samples
total_chars = sum(max(0, int(r.get("input_chars", 0))) for r in runs)
total_time = sum(max(0.0, float(r.get("seconds", 0))) for r in runs)
model_data["samples"] = len(runs)
model_data["seconds_per_char"] = total_time / total_chars if total_chars > 0 else 0.0
self._save_process_data()
    def _bucket_seconds_per_char(self, runs: List[ProcessEstimateSample], **filters: object) -> Optional[float]:
"""Mean seconds-per-character across ``runs`` matching ``filters``.
Returns ``None`` when fewer than ``_BUCKET_MIN_SAMPLES`` runs match
— caller falls back to a wider bucket or a prior multiplier.
"""
matched = [
r for r in runs
if all(r.get(key) == value for key, value in filters.items())
]
if len(matched) < _BUCKET_MIN_SAMPLES:
return None
chars = sum(max(0, int(r.get("input_chars", 0))) for r in matched)
secs = sum(max(0.0, float(r.get("seconds", 0))) for r in matched)
if chars <= 0:
return None
return secs / chars
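    # E.g. ``self._bucket_seconds_per_char(runs, resolution="4k", concurrent=True)``
    # averages only the 4k concurrent runs, and returns ``None`` until that
    # bucket holds at least ``_BUCKET_MIN_SAMPLES`` (3) samples.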
def _resolution_factor(self, runs: List[ProcessEstimateSample], resolution: str) -> float:
"""Multiplier vs the 1080p baseline, observed-or-prior."""
if resolution == DEFAULT_RESOLUTION:
return 1.0
baseline = self._bucket_seconds_per_char(runs, resolution=DEFAULT_RESOLUTION)
observed = self._bucket_seconds_per_char(runs, resolution=resolution)
if baseline and observed and baseline > 0:
return observed / baseline
return _RESOLUTION_PRIOR_FACTOR.get(resolution, 1.0)
def _concurrency_factor(self, runs: List[ProcessEstimateSample]) -> float:
"""Multiplier for concurrent vs solo pipeline runs, observed-or-prior."""
solo = self._bucket_seconds_per_char(runs, concurrent=False)
concurrent = self._bucket_seconds_per_char(runs, concurrent=True)
if solo and concurrent and solo > 0:
return concurrent / solo
return _CONCURRENCY_PRIOR_FACTOR
def predict_process_time(
self,
model_choice: str,
input_chars: int,
resolution: Optional[object] = None,
concurrent: bool = False,
) -> Optional[int]:
"""Predict process seconds factoring in resolution and concurrency.
Stays silent (returns ``None``) until the selected model has at
least ``min_samples`` total runs — matches the user's ask of
"only show ETA after 10 processes".
"""
if input_chars <= 0:
return None
model_data = self.process_data.get("models", {}).get(str(model_choice or "default"))
if not model_data:
return None
min_samples = int(self.process_data.get("min_samples", 10))
runs = model_data.get("runs", []) or []
if len(runs) < min_samples:
return None
target_resolution = normalize_resolution(resolution)
is_concurrent = bool(concurrent)
# Use the most specific bucket that has enough samples; fall back
# to multiplying the broader bucket by an observed-or-prior factor.
bucket_filters: List[Dict[str, object]] = [
{"resolution": target_resolution, "concurrent": is_concurrent},
{"resolution": target_resolution},
{"concurrent": is_concurrent},
{},
]
spc: Optional[float] = None
used_filters: Dict[str, object] = {}
for filt in bucket_filters:
spc = self._bucket_seconds_per_char(runs, **filt)
if spc is not None:
used_filters = filt
break
if spc is None or spc <= 0:
spc = float(model_data.get("seconds_per_char", 0.0))
if spc <= 0:
return None
predicted = spc * input_chars
if "resolution" not in used_filters and target_resolution != DEFAULT_RESOLUTION:
predicted *= self._resolution_factor(runs, target_resolution)
if "concurrent" not in used_filters and is_concurrent:
predicted *= self._concurrency_factor(runs)
return max(5, round(predicted))
    def predict_total_time(
        self,
        model_choice: str,
        input_chars: int,
        estimated_screenshots: int = 10,
        use_cache: bool = False,
        enable_verification: bool = True,
    ) -> float:
        """Predict total seconds required for generation.

        Sums the AI generation, verification, fixed-overhead, and
        screenshot-rendering estimates; never returns less than 5 seconds.
        """
total_seconds = 0.0
# 1. AI Generation Time
if not use_cache and input_chars > 0:
model_data = self.data["models"].get(model_choice, self.data["models"].get("default"))
if model_data and model_data["chars_per_second"] > 0:
ai_seconds = input_chars / model_data["chars_per_second"]
total_seconds += ai_seconds
# 2. Verification Time (if enabled)
if enable_verification:
ver_data = self.data["verification"]
# Assume 1 pass on average (the system does up to 3, but 1 is most common)
total_seconds += ver_data["average_seconds"]
# 3. Add fixed overhead (network/process spin up)
total_seconds += 3.0
# 4. Screenshot Rendering Time
if estimated_screenshots > 0:
screens_data = self.data["screenshots"]
screen_seconds = estimated_screenshots * screens_data["seconds_per_screenshot"]
total_seconds += screen_seconds
        return float(max(5.0, round(total_seconds)))  # floor at 5 seconds
# Shared module-level singleton.
eta_tracker = ETATracker()
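

if __name__ == "__main__":
    # An illustrative smoke test (assumed workflow, not part of the
    # production pipeline): record ten identical runs, then request
    # predictions. Temp paths keep the real config files untouched.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        demo = ETATracker(
            storage_path=os.path.join(tmp, "estimated_times.json"),
            process_storage_path=os.path.join(tmp, "process_times.json"),
        )
        # predict_process_time() returns None until min_samples (10) runs exist.
        for _ in range(10):
            demo.record_process_completion(
                "default", input_chars=5000, total_seconds=120.0,
                resolution="1080p", concurrent=False,
            )
        print(demo.predict_process_time("default", 5000))                   # 120
        print(demo.predict_process_time("default", 5000, resolution="4k"))  # 300 (2.5x prior)
        print(demo.predict_process_time("default", 5000, concurrent=True))  # 180 (1.5x prior)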