| import torch |
| import cv2 |
| import numpy as np |
| from PIL import Image |
| from vljepa.config import Config |
| from vljepa.models import VLJepa |
| from vljepa.utils import nms |
|
|
def load_model(checkpoint_path, device="cpu"):
    """Build a VLJepa model and restore its fine-tuned weights.

    Args:
        checkpoint_path: Path to a ``.pth`` checkpoint containing the keys
            ``predictor_state_dict`` and ``y_projection_state_dict``.
        device: Torch device string the weights are mapped onto.

    Returns:
        Tuple of ``(model, config)`` where the model is in eval mode.
    """
    cfg = Config()
    cfg.device = device
    model = VLJepa(cfg)

    print(f"Loading weights from {checkpoint_path}...")
    state = torch.load(checkpoint_path, map_location=device, weights_only=True)
    # Only the predictor and the y-encoder projection head are restored from
    # the checkpoint; the remaining submodules keep whatever initialization
    # VLJepa(config) gave them (presumably pretrained backbones — confirm).
    model.predictor.load_state_dict(state["predictor_state_dict"])
    model.y_encoder.projection.load_state_dict(state["y_projection_state_dict"])

    model.eval()
    return model, cfg
|
|
def extract_frames(video_path, num_frames=16):
    """Sample RGB frames evenly spaced across a video.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Number of frame indices to sample, spread uniformly
            over ``[0, total_frames - 1]``.

    Returns:
        List of HxWx3 uint8 RGB numpy arrays. May be shorter than
        ``num_frames`` if some reads fail, and empty if the video cannot
        be opened or reports no frames.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        # Guard against files OpenCV cannot open at all.
        if not cap.isOpened():
            return []
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return []

        indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes to BGR; convert to RGB for downstream use.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        return frames
    finally:
        # The original leaked the capture on the early-return path;
        # always release the handle regardless of how we exit.
        cap.release()
|
|
def main():
    """Load a trained checkpoint and encode a sample text query.

    Demo entry point: picks a device, restores the model, and runs a
    single text-encoding pass. Full temporal localization (sliding
    window + NMS) lives in infer.py, as the final message states.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint_path = "best.pth"
    # NOTE(review): video_path is currently unused in this demo; it is
    # presumably a placeholder for the full pipeline in infer.py.
    video_path = "sample_video.mp4"
    query = "a person is opening a door"

    model, config = load_model(checkpoint_path, device)

    print(f"Ready for inference on {device}.")
    print(f"Model architecture: {config.clip_model} + {config.predictor_model} (LoRA) + {config.text_model}")

    # Encode the query text; no gradients are needed at inference time.
    # (The original also called model.query_encoder.tokenize but never used
    # the result — that dead work has been removed.)
    with torch.no_grad():
        text_embedding = model.encode_text([query], device=device)

    print(f"Query: '{query}'")
    print(f"Text embedding shape: {text_embedding.shape}")
    print("\nTo perform full temporal localization, use the infer.py script which implements sliding window and NMS.")
|
|
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|