# viral-images/scoring/engine.py
# Full Viral Images v1.0 implementation - all modules and configs
"""Main scoring engine — orchestrates feature extraction, sub-score computation, and aggregation."""
import os
import time
import yaml
import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional
from dataclasses import dataclass, field
from PIL import Image
from .normalizers import ScoreNormalizer
from features.heuristics import compute_heuristic_features
from features.saliency import compute_saliency_features
from features.ocr import compute_ocr_features
from features.quality import compute_quality_features
from features.semantic import compute_semantic_features, SemanticFeatureExtractor
from features.neural_richness import compute_neural_richness_proxy
from utils.preprocessing import preprocess_image, validate_image
from models.loader import ModelLoader
@dataclass
class ScoreResponse:
    """Structured result of scoring one image; the defaults describe an empty/failed run."""
    # Weighted aggregate of all sub-scores on a 0-100 scale.
    overall_score: float = 0.0
    # Per-dimension 0-100 scores keyed by dimension name (e.g. "concept_match").
    sub_scores: Dict[str, float] = field(default_factory=dict)
    # Confidence in the scoring, 0-1; starts at a base (default 0.85) and is penalized downward.
    confidence: float = 0.85
    # Human-readable notes on high-scoring dimensions (capped at 5 by the engine).
    strengths: list = field(default_factory=list)
    # Human-readable notes on low-scoring dimensions (capped at 5 by the engine).
    weaknesses: list = field(default_factory=list)
    # Actionable suggestions; not populated by ScoringEngine.score in this file.
    suggestions: list = field(default_factory=list)
    # Projected score gain if suggestions are applied; not set in this file.
    projected_improvement: float = 0.0
    # Flat map of extractor outputs, keys prefixed by family (e.g. "ocr_word_count").
    raw_features: Dict[str, Any] = field(default_factory=dict)
    # Processing time, models used, image size, API version, and any error message.
    metadata: Dict[str, Any] = field(default_factory=dict)
class ScoringEngine:
    """Orchestrates the full image scoring pipeline.

    Pipeline: preprocess/validate -> feature extraction (heuristic, saliency,
    OCR, quality, semantic) -> per-dimension normalization into 8 sub-scores
    -> weighted aggregation per use-case preset -> strengths/weaknesses
    summary, all packaged into a ScoreResponse.
    """

    def __init__(self, model_loader: Optional[ModelLoader] = None,
                 config_path: Optional[str] = None):
        """
        Args:
            model_loader: Shared model loader; a fresh one is created if omitted.
            config_path: Explicit path to the scoring-weights YAML; falls back
                to well-known default locations when absent.
        """
        self.model_loader = model_loader or ModelLoader()
        self.config = self._load_config(config_path)
        self.normalizer = ScoreNormalizer(self.config)
        self.extractor = None  # SemanticFeatureExtractor, created lazily once CLIP is available
        self.warmup_complete = False

    @staticmethod
    def _read_yaml(path: str) -> dict:
        """Read one YAML file, mapping an empty document to {}."""
        with open(path, 'r') as f:
            return yaml.safe_load(f) or {}

    def _load_config(self, config_path: Optional[str]) -> dict:
        """Load scoring config from YAML.

        Tries the explicit path first, then default locations; returns {}
        when no config file exists (callers treat missing keys gracefully).
        """
        candidates = []
        if config_path:
            candidates.append(config_path)
        candidates += ['configs/scoring_weights.yaml',
                       'viral-images/configs/scoring_weights.yaml']
        for path in candidates:
            if os.path.exists(path):
                return self._read_yaml(path)
        return {}

    def _ensure_extractor(self):
        """Create the CLIP-backed semantic extractor once, if models loaded.

        Safe to call repeatedly; stays None when CLIP is unavailable, in which
        case semantic features fall back to their extractor-less path.
        """
        if self.extractor is not None:
            return
        clip_model = self.model_loader.get_clip_model()
        if clip_model is not None:
            model, processor = clip_model
            self.extractor = SemanticFeatureExtractor(
                model, processor, self.model_loader.get_device())

    def warmup(self):
        """Pre-load models for faster first inference (idempotent)."""
        if self.warmup_complete:
            return
        self.model_loader.warmup()
        self.warmup_complete = True
        self._ensure_extractor()

    def _extract_features(self, img, concept: str, audience: str, use_case: str):
        """Run every feature extractor on the image.

        Returns:
            Tuple of (heuristic, saliency, ocr, quality, semantic) feature
            dicts, plus a flat dict of all features with family prefixes
            (e.g. "ocr_word_count") for the raw_features response field.
        """
        heuristic_feats = compute_heuristic_features(img)
        saliency_feats = compute_saliency_features(img)       # heuristic fallback inside
        ocr_feats = compute_ocr_features(img)                 # heuristic fallback inside
        quality_feats = compute_quality_features(img)
        semantic_feats = compute_semantic_features(
            img, concept, audience, use_case, self.extractor  # CLIP-based when extractor present
        )
        raw_features = {}
        for prefix, feats in (("heuristic", heuristic_feats),
                              ("saliency", saliency_feats),
                              ("ocr", ocr_feats),
                              ("quality", quality_feats),
                              ("semantic", semantic_feats)):
            raw_features.update({f"{prefix}_{k}": v for k, v in feats.items()})
        return heuristic_feats, saliency_feats, ocr_feats, quality_feats, semantic_feats, raw_features

    def _compute_sub_scores(self, heuristic_feats, saliency_feats, ocr_feats,
                            quality_feats, semantic_feats) -> Dict[str, float]:
        """Map raw feature families onto the 8 normalized 0-100 sub-scores."""
        sub_scores = {}
        sub_scores["concept_match"] = self.normalizer.normalize_concept_match(
            semantic_feats.get("composite_cosine", 0.0)
        )
        sub_scores["visual_focus"] = self.normalizer.normalize_visual_focus(
            saliency_feats.get("peak_saliency", 0.3),
            saliency_feats.get("center_saliency", 0.2),
            saliency_feats.get("top20_fraction", 0.2)
        )
        sub_scores["readability"] = self.normalizer.normalize_readability(
            ocr_feats.get("avg_ocr_confidence", 0.0),
            ocr_feats.get("text_coverage", 0.0),
            ocr_feats.get("word_count", 0),
            ocr_feats.get("has_text", False)
        )
        sub_scores["complexity_balance"] = self.normalizer.normalize_complexity_balance(
            heuristic_feats.get("edge_density", 0.08),
            heuristic_feats.get("color_entropy", 3.0)
        )
        sub_scores["communication_clarity"] = self.normalizer.normalize_communication_clarity(
            heuristic_feats.get("whitespace_ratio", 0.05),
            heuristic_feats.get("contrast", 0.4),
            heuristic_feats.get("symmetry_lr", 0.5),
            heuristic_feats.get("sharpness", 0.3)
        )
        neural_proxy = compute_neural_richness_proxy(
            semantic_feats, saliency_feats, quality_feats, heuristic_feats
        )
        sub_scores["neural_richness"] = self.normalizer.normalize_neural_richness(neural_proxy)
        sub_scores["memorability_proxy"] = self.normalizer.normalize_memorability_proxy(
            quality_feats.get("nima_aesthetic_proxy", 0.5),
            quality_feats.get("colorfulness", 0.5),
            quality_feats.get("sharpness", 0.3)
        )
        # Must come last: it is derived from the other seven sub-scores.
        sub_scores["improvement_potential"] = self.normalizer.normalize_improvement_potential(
            sub_scores
        )
        return sub_scores

    def score(self, image_input, concept: str = "", audience: str = "General",
              use_case: str = "social_media") -> ScoreResponse:
        """
        Score an image across 8 dimensions.
        Args:
            image_input: PIL Image, numpy array, or file path
            concept: User-declared concept/theme
            audience: Target audience
            use_case: Use case preset key
        Returns:
            ScoreResponse with all scores, explanations, and suggestions.
            On preprocessing/validation failure, a default response with
            metadata["error"] set is returned instead of raising.
        """
        t0 = time.time()
        response = ScoreResponse()
        # --- 1. Preprocess & validate ---
        try:
            img = preprocess_image(image_input)
        except Exception as e:
            response.metadata["error"] = f"Preprocessing failed: {e}"
            return response
        val_error = validate_image(img)
        if val_error:
            response.metadata["error"] = val_error
            return response
        # --- 2. Initialize semantic extractor if needed ---
        self._ensure_extractor()
        # --- 3. Extract features ---
        (heuristic_feats, saliency_feats, ocr_feats, quality_feats,
         semantic_feats, raw_features) = self._extract_features(
            img, concept, audience, use_case)
        # --- 4. Compute sub-scores ---
        sub_scores = self._compute_sub_scores(
            heuristic_feats, saliency_feats, ocr_feats, quality_feats, semantic_feats)
        # --- 5. Compute confidence ---
        confidence = self._compute_confidence(
            concept, img.size[0], img.size[1],
            ocr_feats.get("has_text", False),
            self.extractor is not None
        )
        # --- 6. Aggregate ---
        overall = self._aggregate(sub_scores, use_case)
        # --- 7. Strengths and weaknesses ---
        strengths, weaknesses = self._identify_strengths_weaknesses(sub_scores)
        # --- 8. Build response ---
        response.overall_score = round(float(overall), 1)
        response.sub_scores = {k: round(float(v), 1) for k, v in sub_scores.items()}
        response.confidence = round(confidence, 2)
        response.strengths = strengths
        response.weaknesses = weaknesses
        # np.floating covers np.float32, which is NOT a subclass of Python float
        # and would otherwise escape rounding.
        response.raw_features = {
            k: round(float(v), 4) if isinstance(v, (float, np.floating)) else v
            for k, v in raw_features.items()
        }
        # Only report CLIP as used when the extractor actually loaded.
        models_used = ["heuristic-fallback", "saliency-heuristic", "ocr-heuristic"]
        if self.extractor is not None:
            models_used.insert(0, "clip-vit-base-patch32")
        response.metadata = {
            "processing_time_ms": round((time.time() - t0) * 1000, 1),
            "models_used": models_used,
            "neural_richness_mode": "proxy",
            "api_version": "1.0.0",
            "image_size": f"{img.size[0]}x{img.size[1]}",
        }
        return response

    def _compute_confidence(self, concept: str, width: int, height: int,
                            has_text: bool, has_clip: bool) -> float:
        """Compute confidence in the scoring.

        Starts from a configurable base and subtracts penalties for missing
        concept text, low resolution, an unmet text expectation, and a
        missing CLIP model; floored at 0.4.
        """
        scoring_cfg = self.config.get("scoring", {})
        base = scoring_cfg.get("confidence_base", 0.85)
        if not concept or len(concept.strip()) < 3:
            base -= scoring_cfg.get("confidence_no_concept_penalty", 0.15)
        if width < 200 or height < 200:
            base -= scoring_cfg.get("confidence_low_res_penalty", 0.10)
        # The user mentioned text in the concept, but OCR found none.
        if not has_text and concept and "text" in concept.lower():
            base -= scoring_cfg.get("confidence_ocr_uncertain_penalty", 0.10)
        if not has_clip:
            base -= 0.25  # major signal missing
        return max(0.4, base)

    def _aggregate(self, sub_scores: Dict[str, float], use_case: str) -> float:
        """Weighted aggregation of sub-scores.

        Uses the use-case preset weights from config (falling back to the
        "default" preset, then to uniform weights). Sub-scores missing from
        the preset get weight 0; when all weights are 0, falls back to a
        plain mean. Always returns a Python float.
        """
        presets = self.config.get("presets", {})
        weights = presets.get(use_case, presets.get("default", {}))
        if not weights:
            # Uniform weights fallback
            weights = {k: 1.0 / len(sub_scores) for k in sub_scores}
        total = sum(weights.get(k, 0.0) * v for k, v in sub_scores.items())
        weight_sum = sum(weights.get(k, 0.0) for k in sub_scores)
        if weight_sum > 0:
            return float(total / weight_sum)
        return float(np.mean(list(sub_scores.values()))) if sub_scores else 0.0

    def _identify_strengths_weaknesses(self, sub_scores: Dict[str, float],
                                       threshold_high: float = 70.0,
                                       threshold_low: float = 50.0):
        """Identify top strengths and weaknesses from sub-scores.

        Returns:
            (strengths, weaknesses): up to 5 messages each, ordered from the
            highest sub-score downward.
        """
        readable_names = {
            "concept_match": "Concept Match",
            "visual_focus": "Visual Focus",
            "readability": "Readability",
            "complexity_balance": "Complexity Balance",
            "communication_clarity": "Communication Clarity",
            "neural_richness": "Predicted Neural Richness",
            "memorability_proxy": "Memorability",
            "improvement_potential": "Improvement Potential",
        }
        strengths = []
        weaknesses = []
        sorted_scores = sorted(sub_scores.items(), key=lambda x: -x[1])
        for name, score in sorted_scores:
            readable = readable_names.get(name, name)
            if score >= threshold_high:
                strengths.append(f"Strong {readable.lower()}: scored {score:.0f}/100")
            elif score <= threshold_low:
                weaknesses.append(f"Weak {readable.lower()}: scored only {score:.0f}/100")
        return strengths[:5], weaknesses[:5]