Spaces:
Sleeping
Sleeping
Commit ·
4de4725
1
Parent(s): 515f8c0
fixed tasks
Browse files- src/osint_env/data/generator.py +21 -7
- src/osint_env/env/environment.py +21 -1
- tests/test_environment.py +2 -0
src/osint_env/data/generator.py
CHANGED
|
@@ -354,6 +354,14 @@ class DatasetGenerator:
|
|
| 354 |
}.get(difficulty, "full_reward"),
|
| 355 |
}
|
| 356 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
def _infer_answer_from_question(self, question: str, graph: CanonicalGraph) -> str:
|
| 358 |
entities = self._extract_entity_tokens(question)
|
| 359 |
question_l = question.lower()
|
|
@@ -401,11 +409,8 @@ class DatasetGenerator:
|
|
| 401 |
tasks: list[TaskInstance] = []
|
| 402 |
for idx, question_spec in enumerate(self.config.seeding.seeded_questions):
|
| 403 |
answer = question_spec.answer or self._infer_answer_from_question(question_spec.question, graph)
|
| 404 |
-
metadata = dict(question_spec.metadata)
|
| 405 |
-
difficulty =
|
| 406 |
-
metadata["difficulty"] = difficulty
|
| 407 |
-
metadata.setdefault("grader", self._grader_for_difficulty(difficulty))
|
| 408 |
-
metadata.setdefault("scenario", self._task_type_for_difficulty(question_spec.task_type, difficulty))
|
| 409 |
if question_spec.supporting_edges:
|
| 410 |
support = [
|
| 411 |
Edge(src=e.src, rel=e.rel, dst=e.dst, confidence=float(e.confidence))
|
|
@@ -458,6 +463,7 @@ class DatasetGenerator:
|
|
| 458 |
question=q,
|
| 459 |
answer=a,
|
| 460 |
supporting_edges=support,
|
|
|
|
| 461 |
)
|
| 462 |
)
|
| 463 |
return tasks
|
|
@@ -534,7 +540,11 @@ class DatasetGenerator:
|
|
| 534 |
question=question,
|
| 535 |
answer=answer,
|
| 536 |
supporting_edges=support,
|
| 537 |
-
metadata=
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
)
|
| 539 |
)
|
| 540 |
if len(llm_tasks) >= count:
|
|
@@ -579,7 +589,11 @@ class DatasetGenerator:
|
|
| 579 |
question=question,
|
| 580 |
answer=answer,
|
| 581 |
supporting_edges=support,
|
| 582 |
-
metadata=
|
|
|
|
|
|
|
|
|
|
|
|
|
| 583 |
)
|
| 584 |
)
|
| 585 |
if len(llm_tasks) >= count:
|
|
|
|
| 354 |
}.get(difficulty, "full_reward"),
|
| 355 |
}
|
| 356 |
|
| 357 |
+
def _task_metadata(self, index: int, base_task_type: str, metadata: dict[str, Any] | None = None) -> dict[str, Any]:
|
| 358 |
+
out = dict(metadata or {})
|
| 359 |
+
difficulty = self._normalize_difficulty(out.get("difficulty", ""), index)
|
| 360 |
+
out["difficulty"] = difficulty
|
| 361 |
+
out.setdefault("grader", self._grader_for_difficulty(difficulty))
|
| 362 |
+
out.setdefault("scenario", self._task_type_for_difficulty(base_task_type, difficulty))
|
| 363 |
+
return out
|
| 364 |
+
|
| 365 |
def _infer_answer_from_question(self, question: str, graph: CanonicalGraph) -> str:
|
| 366 |
entities = self._extract_entity_tokens(question)
|
| 367 |
question_l = question.lower()
|
|
|
|
| 409 |
tasks: list[TaskInstance] = []
|
| 410 |
for idx, question_spec in enumerate(self.config.seeding.seeded_questions):
|
| 411 |
answer = question_spec.answer or self._infer_answer_from_question(question_spec.question, graph)
|
| 412 |
+
metadata = self._task_metadata(idx, question_spec.task_type, dict(question_spec.metadata))
|
| 413 |
+
difficulty = str(metadata.get("difficulty", "hard"))
|
|
|
|
|
|
|
|
|
|
| 414 |
if question_spec.supporting_edges:
|
| 415 |
support = [
|
| 416 |
Edge(src=e.src, rel=e.rel, dst=e.dst, confidence=float(e.confidence))
|
|
|
|
| 463 |
question=q,
|
| 464 |
answer=a,
|
| 465 |
supporting_edges=support,
|
| 466 |
+
metadata=self._task_metadata(start_idx + i, mode),
|
| 467 |
)
|
| 468 |
)
|
| 469 |
return tasks
|
|
|
|
| 540 |
question=question,
|
| 541 |
answer=answer,
|
| 542 |
supporting_edges=support,
|
| 543 |
+
metadata=self._task_metadata(
|
| 544 |
+
start_idx + len(llm_tasks),
|
| 545 |
+
task_type,
|
| 546 |
+
{"generated_by": "llm", "shared_context": True},
|
| 547 |
+
),
|
| 548 |
)
|
| 549 |
)
|
| 550 |
if len(llm_tasks) >= count:
|
|
|
|
| 589 |
question=question,
|
| 590 |
answer=answer,
|
| 591 |
supporting_edges=support,
|
| 592 |
+
metadata=self._task_metadata(
|
| 593 |
+
start_idx + len(llm_tasks),
|
| 594 |
+
task_type,
|
| 595 |
+
{"generated_by": "llm", "shared_context": True},
|
| 596 |
+
),
|
| 597 |
)
|
| 598 |
)
|
| 599 |
if len(llm_tasks) >= count:
|
src/osint_env/env/environment.py
CHANGED
|
@@ -216,11 +216,31 @@ class OSINTEnvironment(Env):
|
|
| 216 |
def _observation(self) -> Observation:
|
| 217 |
if self.state is None:
|
| 218 |
raise RuntimeError("State is not initialized.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
return Observation(
|
| 220 |
tool_outputs=self.state.tool_outputs[-5:],
|
| 221 |
graph_snapshot=self.memory_graph.to_snapshot(),
|
| 222 |
action_history=self.state.action_history[-10:],
|
| 223 |
-
task=
|
| 224 |
)
|
| 225 |
|
| 226 |
def _info(self) -> dict[str, Any]:
|
|
|
|
| 216 |
def _observation(self) -> Observation:
|
| 217 |
if self.state is None:
|
| 218 |
raise RuntimeError("State is not initialized.")
|
| 219 |
+
metadata = dict(self.state.task.metadata or {})
|
| 220 |
+
grader = metadata.get("grader") if isinstance(metadata.get("grader"), dict) else None
|
| 221 |
+
task_payload = {
|
| 222 |
+
"task_id": self.state.task.task_id,
|
| 223 |
+
"task_type": self.state.task.task_type,
|
| 224 |
+
"question": self.state.task.question,
|
| 225 |
+
"difficulty": self.state.difficulty,
|
| 226 |
+
"grader": (
|
| 227 |
+
dict(grader)
|
| 228 |
+
if grader is not None
|
| 229 |
+
else {
|
| 230 |
+
"type": "difficulty_exact_match",
|
| 231 |
+
"answer_type": "node_id",
|
| 232 |
+
"case_sensitive": True,
|
| 233 |
+
"reward_profile": self.state.difficulty,
|
| 234 |
+
}
|
| 235 |
+
),
|
| 236 |
+
}
|
| 237 |
+
if "scenario" in metadata:
|
| 238 |
+
task_payload["scenario"] = str(metadata.get("scenario", ""))
|
| 239 |
return Observation(
|
| 240 |
tool_outputs=self.state.tool_outputs[-5:],
|
| 241 |
graph_snapshot=self.memory_graph.to_snapshot(),
|
| 242 |
action_history=self.state.action_history[-10:],
|
| 243 |
+
task=task_payload,
|
| 244 |
)
|
| 245 |
|
| 246 |
def _info(self) -> dict[str, Any]:
|
tests/test_environment.py
CHANGED
|
@@ -6,6 +6,8 @@ def test_episode_flow():
|
|
| 6 |
env = OSINTEnvironment(EnvironmentConfig(max_steps=5, seed=5))
|
| 7 |
obs = env.reset()
|
| 8 |
assert "question" in obs.task
|
|
|
|
|
|
|
| 9 |
obs, r1, done, _ = env.step(Action(ActionType.CALL_TOOL, {"tool_name": "search_posts", "args": {"query": "Update"}}))
|
| 10 |
assert done is False
|
| 11 |
assert isinstance(r1, float)
|
|
|
|
| 6 |
env = OSINTEnvironment(EnvironmentConfig(max_steps=5, seed=5))
|
| 7 |
obs = env.reset()
|
| 8 |
assert "question" in obs.task
|
| 9 |
+
assert isinstance(obs.task.get("grader"), dict)
|
| 10 |
+
assert "type" in obs.task["grader"]
|
| 11 |
obs, r1, done, _ = env.step(Action(ActionType.CALL_TOOL, {"tool_name": "search_posts", "args": {"query": "Update"}}))
|
| 12 |
assert done is False
|
| 13 |
assert isinstance(r1, float)
|