File size: 7,485 Bytes
d814291
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from __future__ import annotations

from collections import defaultdict
from typing import Any

from osint_env.data.generator import PlatformViews


class ToolRegistry:
    """Query tools over a ``PlatformViews`` object.

    Lookup tables are built once from ``views`` at construction time (see
    ``_index``). Each public method is a "tool" that returns a plain dict, and
    ``call`` dispatches to those tools by name.
    """

    def __init__(self, views: PlatformViews):
        self.views = views
        # alias id -> owning (canonical) user id; optional on the views object.
        self.alias_lookup = dict(getattr(views, "alias_lookup", None) or {})
        self._index()

    @staticmethod
    def _normalize_lookup_token(value: str) -> str:
        """Lower-case ``value``, strip one known id prefix and turn ``_`` into spaces.

        Example: ``"org_acme_corp"`` -> ``"acme corp"``.
        """
        token = str(value or "").strip().lower()
        for prefix in ("org_", "loc_", "event_", "post_", "thr_", "thread_", "alias_", "user_"):
            if token.startswith(prefix):
                token = token[len(prefix) :]
                break
        return token.replace("_", " ")

    def _resolve_user_ids(self, user_id: str) -> list[str]:
        """Return ``user_id`` plus the ids linked to it through ``alias_lookup``.

        Order: ``user_id`` itself, then its owner (when ``user_id`` is an
        alias), then every alias whose owner is ``user_id``. Other aliases of
        the owner are not followed. Blank input yields ``[]``.
        """
        user_id = str(user_id or "").strip()
        if not user_id:
            return []
        resolved = [user_id]
        canonical = self.alias_lookup.get(user_id)
        if canonical and canonical not in resolved:
            resolved.append(canonical)
        for alias_id, owner in self.alias_lookup.items():
            if owner == user_id and alias_id not in resolved:
                resolved.append(alias_id)
        return resolved

    def _index(self) -> None:
        """Build the lookup tables the tools read from.

        - ``posts_by_user``: user id -> posts authored by that id, plus posts
          whose ``canonical_user`` is that id.
        - ``mentions_by_user``: user id -> posts that mention it.
        - ``posts_by_id`` / ``threads_by_id`` / ``profiles_by_user``: id -> record.
        - ``activity_by_user``: user id -> ``{"kind": "thread"|"comment", "thread_id": ...}``.
        """
        self.posts_by_user: dict[str, list[dict[str, Any]]] = defaultdict(list)
        self.mentions_by_user: dict[str, list[dict[str, Any]]] = defaultdict(list)
        self.posts_by_id = {post["post_id"]: post for post in self.views.microblog_posts}
        for post in self.views.microblog_posts:
            author = post["user_id"]
            self.posts_by_user[author].append(post)
            canonical_user = post.get("canonical_user")
            # File the post under its owning identity too, but only once when
            # the author already *is* the canonical user.
            if canonical_user and canonical_user != author:
                self.posts_by_user[canonical_user].append(post)
            # A user mentioned several times in one post is indexed once.
            for mentioned in dict.fromkeys(post.get("mentions") or []):
                self.mentions_by_user[mentioned].append(post)

        self.threads_by_id = {t["thread_id"]: t for t in self.views.forum_threads}
        self.activity_by_user: dict[str, list[dict[str, Any]]] = defaultdict(list)
        for thread in self.views.forum_threads:
            thread_id = thread["thread_id"]
            self.activity_by_user[thread["author_id"]].append({"kind": "thread", "thread_id": thread_id})
            for comment in thread.get("comments", []):
                self.activity_by_user[comment["user_id"]].append({"kind": "comment", "thread_id": thread_id})

        self.profiles_by_user = {p["user_id"]: p for p in self.views.profiles}

    def call(self, tool_name: str, args: dict[str, Any] | None = None) -> dict[str, Any]:
        """Dispatch ``tool_name`` to the public tool method of the same name.

        Only public callables are reachable. Names starting with ``_``,
        ``call`` itself and non-callable attributes (``views``,
        ``alias_lookup``, the index dicts) are rejected, so a caller-supplied
        name cannot re-run ``_index``/``__init__`` or recurse.

        Args:
            tool_name: name of the tool method.
            args: keyword arguments for the tool; ``None`` means no arguments.

        Raises:
            ValueError: if ``tool_name`` does not name a tool.
        """
        name = str(tool_name or "")
        if not name or name.startswith("_") or name == "call":
            raise ValueError(f"Unknown tool: {tool_name}")
        fn = getattr(self, name, None)
        if not callable(fn):
            raise ValueError(f"Unknown tool: {tool_name}")
        return fn(**(args or {}))

    @staticmethod
    def _post_matches(post: dict[str, Any], needle: str) -> bool:
        """Return True if lower-cased ``needle`` occurs in any searchable post field."""
        scalar_fields = (
            post.get("text", ""),
            post.get("post_id", ""),
            post.get("user_id", ""),
            post.get("canonical_user", ""),
        )
        if any(needle in str(value).lower() for value in scalar_fields):
            return True
        references = [*post.get("references", []), *post.get("reference_names", [])]
        return any(needle in str(ref).lower() for ref in references)

    def search_posts(self, query: str, time_range: tuple[int, int] | None = None) -> dict[str, Any]:
        """Case-insensitive substring search over posts.

        Matches against text, post id, author id, canonical user and references.

        Args:
            query: substring to look for. Empty matches every post in range.
            time_range: inclusive ``(start, end)`` timestamp bounds. ``None``
                (or an empty value) disables time filtering, and a ``None``
                bound leaves that side open.

        Returns:
            ``{"results": first 20 matches, "count": total matches}``.
        """
        start, end = time_range or (None, None)
        lower = float("-inf") if start is None else start
        upper = float("inf") if end is None else end
        needle = str(query or "").lower()
        results = [
            post
            for post in self.views.microblog_posts
            if lower <= post["timestamp"] <= upper and self._post_matches(post, needle)
        ]
        return {"results": results[:20], "count": len(results)}

    def get_post(self, post_id: str) -> dict[str, Any]:
        """Return the post with ``post_id`` (exact id match) and whether it exists."""
        post = self.posts_by_id.get(post_id)
        return {"result": post, "found": post is not None}

    def _collect_for_user(self, index: dict[str, list[dict[str, Any]]], user_id: str, key) -> list[dict[str, Any]]:
        """Gather ``index`` entries for every id resolved from ``user_id``.

        Entries are de-duplicated by ``key(entry)`` and keep first-seen order.
        """
        collected: list[dict[str, Any]] = []
        seen: set[Any] = set()
        for resolved_id in self._resolve_user_ids(user_id):
            for entry in index.get(resolved_id, []):
                marker = key(entry)
                if marker in seen:
                    continue
                seen.add(marker)
                collected.append(entry)
        return collected

    def get_user_posts(self, user_id: str) -> dict[str, Any]:
        """Return posts by ``user_id`` and its linked alias/canonical ids, de-duplicated by post id."""
        results = self._collect_for_user(self.posts_by_user, user_id, lambda post: str(post.get("post_id", "")))
        return {"results": results, "count": len(results)}

    def get_mentions(self, user_id: str) -> dict[str, Any]:
        """Return posts mentioning ``user_id`` or its linked ids, de-duplicated by post id."""
        results = self._collect_for_user(self.mentions_by_user, user_id, lambda post: str(post.get("post_id", "")))
        return {"results": results, "count": len(results)}

    def search_threads(self, topic: str) -> dict[str, Any]:
        """Find forum threads by topic, thread id or title.

        The topic must match exactly, ignoring case and surrounding whitespace.
        Thread id and title use case-insensitive substring matching.

        Returns:
            ``{"results": first 20 matches, "count": total matches}``.
        """
        needle = str(topic or "").strip().lower()
        results = [
            thread
            for thread in self.views.forum_threads
            if str(thread.get("topic", "")).strip().lower() == needle
            or needle in str(thread.get("thread_id", "")).lower()
            or needle in str(thread.get("title", "")).lower()
        ]
        return {"results": results[:20], "count": len(results)}

    def get_thread(self, thread_id: str) -> dict[str, Any]:
        """Return the thread with ``thread_id`` (exact id match) and whether it exists."""
        thread = self.threads_by_id.get(thread_id)
        return {"result": thread, "found": thread is not None}

    def get_user_activity(self, user_id: str) -> dict[str, Any]:
        """Return thread/comment activity for ``user_id`` and its linked ids.

        Entries are de-duplicated by ``(kind, thread_id)``.
        """
        acts = self._collect_for_user(
            self.activity_by_user,
            user_id,
            lambda activity: (activity.get("kind"), activity.get("thread_id")),
        )
        return {"results": acts, "count": len(acts)}

    def _find_profile(self, user_id: str) -> dict[str, Any] | None:
        """Return the first non-empty profile among the ids resolved from ``user_id``."""
        for candidate in self._resolve_user_ids(user_id):
            profile = self.profiles_by_user.get(candidate)
            if profile:
                return profile
        return None

    def get_profile(self, user_id: str) -> dict[str, Any]:
        """Return the profile for ``user_id`` (following alias links) and whether one exists."""
        profile = self._find_profile(user_id)
        return {"result": profile, "found": profile is not None}

    def search_people(self, name: str | None = None, org: str | None = None) -> dict[str, Any]:
        """Filter profiles by name and/or organisation (case-insensitive substring).

        ``name`` matches the profile name, user id or any alias id. ``org``
        matches the org name or org id. It also matches the org name after the
        query is normalised, so ``"org_acme_corp"`` matches ``"Acme Corp"``.

        Returns:
            ``{"results": first 20 matches, "count": total matches}``.
        """
        results = self.views.profiles
        if name:
            name_query = str(name).lower()
            results = [
                p
                for p in results
                if name_query in str(p.get("name", "")).lower()
                or name_query in str(p.get("user_id", "")).lower()
                or any(name_query in str(alias).lower() for alias in p.get("alias_ids", []))
            ]
        if org:
            org_query = str(org).lower()
            normalized_org = self._normalize_lookup_token(org_query)
            results = [
                p
                for p in results
                if org_query in str(p.get("org", "")).lower()
                or org_query in str(p.get("org_id", "")).lower()
                or (normalized_org and normalized_org in str(p.get("org", "")).lower())
            ]
        return {"results": results[:20], "count": len(results)}

    def get_connections(self, user_id: str) -> dict[str, Any]:
        """Return the ``connections`` list of the profile for ``user_id``.

        Alias links are followed. A missing profile, or a profile without
        connections, yields an empty list.
        """
        profile = self._find_profile(user_id)
        connections = (profile.get("connections") if profile else None) or []
        return {"results": connections, "count": len(connections)}