Spaces:
Paused
Paused
| from __future__ import annotations | |
| from collections import defaultdict | |
| from typing import Any | |
| from osint_env.data.generator import PlatformViews | |
class ToolRegistry:
    """Query tools over the generated platform views, invocable by name.

    At construction the registry indexes ``views.microblog_posts``,
    ``views.forum_threads`` and ``views.profiles`` and keeps a copy of the
    optional ``views.alias_lookup`` mapping (alias id -> owner id). The public
    methods are the tools reachable through :meth:`call`. List/search tools
    return ``{"results": [...], "count": n}``; single-item lookups return
    ``{"result": item_or_None, "found": bool}``.
    """

    # Search tools return at most this many items in "results"; "count" is
    # always the total number of matches.
    _MAX_SEARCH_RESULTS = 20

    def __init__(self, views: PlatformViews):
        self.views = views
        # Shallow copy of the optional alias map; a missing or None
        # ``alias_lookup`` is treated as "no aliases".
        self.alias_lookup = dict(getattr(views, "alias_lookup", None) or {})
        self._index()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _normalize_lookup_token(value: str) -> str:
        """Turn an id-like token such as ``"org_acme_corp"`` into ``"acme corp"``.

        Strips and lower-cases ``value``, removes at most one known entity
        prefix (``org_``, ``loc_``, ``event_``, ...), then replaces the remaining
        underscores with spaces so the token can be matched against
        human-readable names.
        """
        token = str(value or "").strip().lower()
        for prefix in ("org_", "loc_", "event_", "post_", "thr_", "thread_", "alias_", "user_"):
            if token.startswith(prefix):
                token = token[len(prefix):]
                break
        return token.replace("_", " ")

    @staticmethod
    def _as_search_text(value: Any) -> str:
        """Lower-cased string form of ``value`` for substring matching.

        ``None`` maps to ``""``: stringifying it would yield ``"none"``, which
        spuriously matches queries such as ``"on"`` or ``"one"``.
        """
        return "" if value is None else str(value).lower()

    def _resolve_user_ids(self, user_id: str) -> list[str]:
        """Return ``user_id`` plus the ids linked to it through ``alias_lookup``.

        The list starts with the stripped id itself. It then contains the id's
        owner, if the id is a known alias, followed by every alias whose owner
        is the id. Order is preserved and there are no duplicates. A blank id
        yields ``[]``. Other aliases of the same owner (siblings) are not
        included.
        """
        user_id = str(user_id or "").strip()
        if not user_id:
            return []
        resolved = [user_id]
        owner = self.alias_lookup.get(user_id)
        if owner and owner not in resolved:
            resolved.append(owner)
        # alias_lookup is keyed by alias, so finding an owner's aliases is a scan.
        for alias_id, alias_owner in self.alias_lookup.items():
            if alias_owner == user_id and alias_id not in resolved:
                resolved.append(alias_id)
        return resolved

    def _index(self) -> None:
        """Build the per-id lookup tables used by the tools.

        Posts are indexed by post id. They are also indexed under their author
        ``user_id``, under their ``canonical_user`` when that is set and differs
        from the author, and under every id in their ``mentions``. Forum
        activity gets one entry per thread authored and one per comment written.
        Profiles are indexed by ``user_id``.
        """
        self.posts_by_user: dict[str, list[dict[str, Any]]] = defaultdict(list)
        self.mentions_by_user: dict[str, list[dict[str, Any]]] = defaultdict(list)
        self.posts_by_id = {post["post_id"]: post for post in self.views.microblog_posts}
        for post in self.views.microblog_posts:
            author = post["user_id"]
            self.posts_by_user[author].append(post)
            canonical_user = post.get("canonical_user")
            # Without the inequality check the same post would be listed twice
            # under the author's id.
            if canonical_user and canonical_user != author:
                self.posts_by_user[canonical_user].append(post)
            for mentioned in post.get("mentions") or []:
                self.mentions_by_user[mentioned].append(post)

        self.threads_by_id = {t["thread_id"]: t for t in self.views.forum_threads}
        self.activity_by_user: dict[str, list[dict[str, Any]]] = defaultdict(list)
        for thread in self.views.forum_threads:
            thread_id = thread["thread_id"]
            self.activity_by_user[thread["author_id"]].append({"kind": "thread", "thread_id": thread_id})
            for comment in thread.get("comments") or []:
                self.activity_by_user[comment["user_id"]].append({"kind": "comment", "thread_id": thread_id})

        self.profiles_by_user = {p["user_id"]: p for p in self.views.profiles}

    def _collect_posts(self, index: dict[str, list[dict[str, Any]]], user_id: str) -> dict[str, Any]:
        """Gather posts from ``index`` for every id that ``user_id`` resolves to.

        Posts come back in resolution order, with each post id at most once.
        """
        results: list[dict[str, Any]] = []
        seen_post_ids: set[str] = set()
        for resolved_id in self._resolve_user_ids(user_id):
            # .get() so lookups never insert keys into the defaultdict index.
            for post in index.get(resolved_id, []):
                post_id = str(post.get("post_id", ""))
                if post_id not in seen_post_ids:
                    seen_post_ids.add(post_id)
                    results.append(post)
        return {"results": results, "count": len(results)}

    def _find_profile(self, user_id: str) -> dict[str, Any] | None:
        """Return the first non-empty profile among the ids ``user_id`` resolves to."""
        for candidate in self._resolve_user_ids(user_id):
            profile = self.profiles_by_user.get(candidate)
            if profile:
                return profile
        return None

    # ------------------------------------------------------------- dispatching

    def call(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any]:
        """Invoke the tool named ``tool_name`` with keyword arguments ``args``.

        Only public, callable attributes are reachable. Names starting with an
        underscore, ``call`` itself, and data attributes such as ``views`` are
        rejected, so a (presumably agent-supplied) name cannot reach internals.
        ``args`` may be ``None`` for calls that pass no arguments.

        Raises:
            ValueError: if ``tool_name`` is not an exposed tool.
            TypeError: if ``args`` does not match the tool's signature.
        """
        if not isinstance(tool_name, str) or not tool_name or tool_name.startswith("_") or tool_name == "call":
            raise ValueError(f"Unknown tool: {tool_name}")
        fn = getattr(self, tool_name, None)
        if not callable(fn):
            raise ValueError(f"Unknown tool: {tool_name}")
        return fn(**(args or {}))

    # -------------------------------------------------------------- microblog

    def search_posts(
        self,
        query: str,
        time_range: tuple[int | None, int | None] | None = None,
    ) -> dict[str, Any]:
        """Case-insensitive substring search over microblog posts.

        A post matches when the stripped ``query`` occurs in its text, post id,
        author id, canonical user, or any of its ``references`` or
        ``reference_names``. An empty query matches every post.

        ``time_range`` is an inclusive ``(start, end)`` timestamp window, and
        either bound may be ``None`` to leave that side open. When no window is
        given, no time filtering is applied. The old implicit ``(0, 10**9)``
        default dropped every epoch-seconds timestamp after 2001.

        Returns the first ``_MAX_SEARCH_RESULTS`` matches plus the total count.
        """
        needle = str(query or "").strip().lower()
        start, end = time_range or (None, None)
        results: list[dict[str, Any]] = []
        for post in self.views.microblog_posts:
            if start is not None and post["timestamp"] < start:
                continue
            if end is not None and post["timestamp"] > end:
                continue
            candidates = [
                post.get("text"),
                post.get("post_id"),
                post.get("user_id"),
                post.get("canonical_user"),
                *(post.get("references") or []),
                *(post.get("reference_names") or []),
            ]
            if any(needle in self._as_search_text(value) for value in candidates):
                results.append(post)
        return {"results": results[: self._MAX_SEARCH_RESULTS], "count": len(results)}

    def get_post(self, post_id: str) -> dict[str, Any]:
        """Look up a single post by its exact id."""
        post = self.posts_by_id.get(post_id)
        return {"result": post, "found": post is not None}

    def get_user_posts(self, user_id: str) -> dict[str, Any]:
        """All posts authored by ``user_id`` or by ids linked to it via aliases."""
        return self._collect_posts(self.posts_by_user, user_id)

    def get_mentions(self, user_id: str) -> dict[str, Any]:
        """All posts mentioning ``user_id`` or any id linked to it via aliases."""
        return self._collect_posts(self.mentions_by_user, user_id)

    # ------------------------------------------------------------------ forum

    def search_threads(self, topic: str) -> dict[str, Any]:
        """Find forum threads by topic, id or title.

        The topic is compared case-insensitively for equality. The thread id
        and title are matched as case-insensitive substrings. An empty
        ``topic`` matches every thread.
        """
        needle = str(topic or "").strip().lower()
        results = [
            thread
            for thread in self.views.forum_threads
            if self._as_search_text(thread.get("topic")).strip() == needle
            or needle in self._as_search_text(thread.get("thread_id"))
            or needle in self._as_search_text(thread.get("title"))
        ]
        return {"results": results[: self._MAX_SEARCH_RESULTS], "count": len(results)}

    def get_thread(self, thread_id: str) -> dict[str, Any]:
        """Look up a single forum thread by its exact id."""
        thread = self.threads_by_id.get(thread_id)
        return {"result": thread, "found": thread is not None}

    def get_user_activity(self, user_id: str) -> dict[str, Any]:
        """Forum activity for ``user_id`` and its alias-linked ids.

        Each entry is ``{"kind": "thread" | "comment", "thread_id": ...}``. A
        given (kind, thread) pair appears at most once.
        """
        acts: list[dict[str, Any]] = []
        seen: set[tuple[Any, Any]] = set()
        for resolved_id in self._resolve_user_ids(user_id):
            for activity in self.activity_by_user.get(resolved_id, []):
                key = (activity.get("kind"), activity.get("thread_id"))
                if key not in seen:
                    seen.add(key)
                    acts.append(activity)
        return {"results": acts, "count": len(acts)}

    # --------------------------------------------------------------- profiles

    def get_profile(self, user_id: str) -> dict[str, Any]:
        """Profile for ``user_id``, falling back to its alias-linked ids."""
        profile = self._find_profile(user_id)
        return {"result": profile, "found": profile is not None}

    def search_people(self, name: str | None = None, org: str | None = None) -> dict[str, Any]:
        """Filter profiles by name and/or organisation (case-insensitive substrings).

        ``name`` is matched against the profile name, user id and alias ids.
        ``org`` is matched against the org name and org id. It is also
        normalised (e.g. ``"org_acme_corp"`` -> ``"acme corp"``) and matched
        against the org name. Blank or omitted filters are ignored.
        """
        results = self.views.profiles
        name_query = str(name or "").strip().lower()
        if name_query:
            results = [
                p
                for p in results
                if name_query in self._as_search_text(p.get("name"))
                or name_query in self._as_search_text(p.get("user_id"))
                or any(name_query in self._as_search_text(alias) for alias in p.get("alias_ids") or [])
            ]
        org_query = str(org or "").strip().lower()
        if org_query:
            normalized_org = self._normalize_lookup_token(org_query)
            results = [
                p
                for p in results
                if org_query in self._as_search_text(p.get("org"))
                or org_query in self._as_search_text(p.get("org_id"))
                or (normalized_org and normalized_org in self._as_search_text(p.get("org")))
            ]
        return {"results": results[: self._MAX_SEARCH_RESULTS], "count": len(results)}

    def get_connections(self, user_id: str) -> dict[str, Any]:
        """Connections listed on the profile resolved for ``user_id`` (copied)."""
        profile = self._find_profile(user_id)
        connections = list(profile.get("connections") or []) if profile else []
        return {"results": connections, "count": len(connections)}