| """ |
| V-JEPA 2 Custom Inference Handler for Hugging Face Inference Endpoints |
| Model: facebook/vjepa2-vitl-fpc64-256 (Large variant - good balance of performance/resources) |
| |
| For ProofPath video assessment - extracts motion features from skill demonstration videos. |
| """ |
|
|
| from typing import Dict, List, Any, Optional |
| import torch |
| import numpy as np |
| import base64 |
| import io |
| import tempfile |
| import os |
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler that extracts V-JEPA 2 video features."""

    def __init__(self, path: str = "", model_id: str = "facebook/vjepa2-vitl-fpc64-256"):
        """
        Initialize V-JEPA 2 model for video feature extraction.

        Args:
            path: Path to the model directory (provided by HF Inference Endpoints).
                Not used for loading: weights are loaded from ``model_id``.
            model_id: Hub id (or local directory) of the V-JEPA 2 checkpoint to load.
        """
        from transformers import AutoVideoProcessor, AutoModel

        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        self.processor = AutoVideoProcessor.from_pretrained(model_id)
        self.model = AutoModel.from_pretrained(
            model_id,
            # float16 on GPU, float32 on CPU.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
            attn_implementation="sdpa",
        )
        if not use_cuda:
            # No device_map on CPU, so the model must be placed explicitly.
            self.model = self.model.to(self.device)
        self.model.eval()

        # Frames sampled per request when "num_frames" is not given.
        self.default_num_frames = 64

    def _decode_video(self, video_data: Any):
        """
        Build a torchcodec ``VideoDecoder`` from one of the supported input formats.

        Supports:
            - ``http://`` / ``https://`` URL strings (handed to the decoder as-is)
            - ``data:`` URLs; the payload after the first comma is base64-decoded
            - any other string, treated as base64-encoded video bytes
            - raw ``bytes`` / ``bytearray``

        Returns:
            A ``torchcodec.decoders.VideoDecoder`` instance.

        Raises:
            ValueError: unsupported input type, a ``data:`` URL without a comma,
                or (as ``binascii.Error``) an undecodable base64 payload.
        """
        from torchcodec.decoders import VideoDecoder

        if isinstance(video_data, str):
            if video_data.startswith(("http://", "https://")):
                return VideoDecoder(video_data)
            if video_data.startswith("data:"):
                _header, sep, encoded = video_data.partition(",")
                if not sep:
                    raise ValueError("Malformed data URL: missing ',' before the payload.")
                video_bytes = base64.b64decode(encoded)
            else:
                video_bytes = base64.b64decode(video_data)
        elif isinstance(video_data, (bytes, bytearray)):
            video_bytes = bytes(video_data)
        else:
            raise ValueError(f"Unsupported video input type: {type(video_data)}")

        # Decode straight from memory: no temporary file that could be leaked
        # if the decoder rejects the data.
        return VideoDecoder(video_bytes)

    def _sample_frames(
        self,
        video_decoder,
        num_frames: int = 64,
        sampling_strategy: str = "uniform",
    ) -> torch.Tensor:
        """
        Sample frames from a video decoder.

        Args:
            video_decoder: torchcodec VideoDecoder (anything exposing ``metadata``,
                ``get_frames_at(indices=...)`` and, as a fallback, ``len()``).
            num_frames: Maximum number of frames to sample; must be >= 1.
            sampling_strategy: "uniform" (evenly spaced over the whole video),
                "random" (sorted random subset without replacement); any other
                value takes the first ``num_frames`` consecutive frames.

        Returns:
            The ``.data`` tensor of the decoded frame batch. It holds at most
            ``num_frames`` frames; fewer when the video is shorter.

        Raises:
            ValueError: if ``num_frames`` < 1 or the video reports no frames.
        """
        num_frames = int(num_frames)
        if num_frames < 1:
            raise ValueError(f"num_frames must be >= 1, got {num_frames}")

        total_frames = getattr(video_decoder.metadata, "num_frames", None)
        if total_frames is None:
            # Metadata has no frame count; ask the decoder itself.
            total_frames = len(video_decoder)
        total_frames = int(total_frames)
        if total_frames < 1:
            raise ValueError("Video contains no frames.")

        if sampling_strategy == "uniform":
            if total_frames <= num_frames:
                frame_idx = np.arange(total_frames)
            else:
                frame_idx = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        elif sampling_strategy == "random":
            frame_idx = np.sort(
                np.random.choice(total_frames, min(num_frames, total_frames), replace=False)
            )
        else:
            frame_idx = np.arange(min(num_frames, total_frames))

        # .tolist() yields plain Python ints for the decoder API.
        return video_decoder.get_frames_at(indices=frame_idx.tolist()).data

    def _to_model(self, value: Any) -> Any:
        """Move one processor output onto the model's device, matching its float dtype."""
        if not isinstance(value, torch.Tensor):
            return value
        if torch.is_floating_point(value):
            # The model may be float16 (CUDA path in __init__); float32 inputs
            # would otherwise trigger a dtype mismatch inside the model.
            return value.to(device=self.model.device, dtype=self.model.dtype)
        return value.to(self.model.device)

    @staticmethod
    def _pool(features: torch.Tensor, pooling: str) -> torch.Tensor:
        """
        Pool a hidden-state tensor along axis 1.

        Assumes ``features`` is rank 3 (presumably batch, tokens, dim — confirm).
        "mean" averages axis 1. "cls" selects index 0 of axis 1; NOTE(review): V-JEPA 2
        is not known to prepend a CLS token, so this is likely just the first
        patch token. Any other value returns the tensor unpooled. Output is
        float32 so half-precision activations are averaged and serialized in full
        precision.
        """
        features = features.float()
        if pooling == "mean":
            return features.mean(dim=1)
        if pooling == "cls":
            return features[:, 0, :]
        return features

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process video and extract V-JEPA 2 features.

        Expected input format:
        {
            "inputs": <base64_video_string or video_url>,   # "video" is accepted too
            "parameters": {
                "num_frames": 64,                # Optional: number of frames to sample
                "sampling_strategy": "uniform",  # Optional: "uniform" or "random"
                "return_predictor": true,        # Optional: also return predictor features
                "pooling": "mean"                # Optional: "mean", "cls", or "none"
            }
        }

        Returns:
            On success:
            {
                "encoder_features": [...],    # pooled encoder output as nested lists
                "feature_shape": [...],       # shape of encoder_features, e.g. [1, D] for
                                              # "mean"/"cls" (leading axis presumably batch)
                "predictor_features": [...],  # only if requested AND produced by the model
                "predictor_shape": [...],
            }
            On any failure during processing:
            {"error": <message>, "error_type": <exception class name>}

        Raises:
            ValueError: if neither "inputs" nor "video" is present in ``data``.
        """
        inputs = data.get("inputs")
        if inputs is None:
            inputs = data.get("video")
        if inputs is None:
            raise ValueError("No video input provided. Use 'inputs' or 'video' key.")

        # "parameters": null behaves like an absent key.
        params = data.get("parameters") or {}
        num_frames = params.get("num_frames", self.default_num_frames)
        sampling_strategy = params.get("sampling_strategy", "uniform")
        return_predictor = params.get("return_predictor", False)
        pooling = params.get("pooling", "mean")

        try:
            video_decoder = self._decode_video(inputs)
            frames = self._sample_frames(video_decoder, num_frames, sampling_strategy)

            processed = self.processor(frames, return_tensors="pt")
            processed = {k: self._to_model(v) for k, v in processed.items()}

            with torch.no_grad():
                outputs = self.model(**processed)

            encoder_pooled = self._pool(outputs.last_hidden_state, pooling)
            result = {
                "encoder_features": encoder_pooled.cpu().numpy().tolist(),
                "feature_shape": list(encoder_pooled.shape),
            }

            # The attribute may exist but be None, so test the value, not hasattr.
            predictor_output = getattr(outputs, "predictor_output", None)
            if return_predictor and predictor_output is not None:
                predictor_pooled = self._pool(predictor_output.last_hidden_state, pooling)
                result["predictor_features"] = predictor_pooled.cpu().numpy().tolist()
                result["predictor_shape"] = list(predictor_pooled.shape)

            return result

        except Exception as e:
            # Endpoint boundary: report the failure to the client instead of crashing.
            return {"error": str(e), "error_type": type(e).__name__}
|
|