Siddeshwar1625 committed on
Commit
4de4725
·
1 Parent(s): 515f8c0

fixed tasks

Browse files
src/osint_env/data/generator.py CHANGED
@@ -354,6 +354,14 @@ class DatasetGenerator:
354
  }.get(difficulty, "full_reward"),
355
  }
356
 
 
 
 
 
 
 
 
 
357
  def _infer_answer_from_question(self, question: str, graph: CanonicalGraph) -> str:
358
  entities = self._extract_entity_tokens(question)
359
  question_l = question.lower()
@@ -401,11 +409,8 @@ class DatasetGenerator:
401
  tasks: list[TaskInstance] = []
402
  for idx, question_spec in enumerate(self.config.seeding.seeded_questions):
403
  answer = question_spec.answer or self._infer_answer_from_question(question_spec.question, graph)
404
- metadata = dict(question_spec.metadata)
405
- difficulty = self._normalize_difficulty(metadata.get("difficulty", ""), idx)
406
- metadata["difficulty"] = difficulty
407
- metadata.setdefault("grader", self._grader_for_difficulty(difficulty))
408
- metadata.setdefault("scenario", self._task_type_for_difficulty(question_spec.task_type, difficulty))
409
  if question_spec.supporting_edges:
410
  support = [
411
  Edge(src=e.src, rel=e.rel, dst=e.dst, confidence=float(e.confidence))
@@ -458,6 +463,7 @@ class DatasetGenerator:
458
  question=q,
459
  answer=a,
460
  supporting_edges=support,
 
461
  )
462
  )
463
  return tasks
@@ -534,7 +540,11 @@ class DatasetGenerator:
534
  question=question,
535
  answer=answer,
536
  supporting_edges=support,
537
- metadata={"generated_by": "llm", "shared_context": True},
 
 
 
 
538
  )
539
  )
540
  if len(llm_tasks) >= count:
@@ -579,7 +589,11 @@ class DatasetGenerator:
579
  question=question,
580
  answer=answer,
581
  supporting_edges=support,
582
- metadata={"generated_by": "llm", "shared_context": True},
 
 
 
 
583
  )
584
  )
585
  if len(llm_tasks) >= count:
 
354
  }.get(difficulty, "full_reward"),
355
  }
356
 
357
+ def _task_metadata(self, index: int, base_task_type: str, metadata: dict[str, Any] | None = None) -> dict[str, Any]:
358
+ out = dict(metadata or {})
359
+ difficulty = self._normalize_difficulty(out.get("difficulty", ""), index)
360
+ out["difficulty"] = difficulty
361
+ out.setdefault("grader", self._grader_for_difficulty(difficulty))
362
+ out.setdefault("scenario", self._task_type_for_difficulty(base_task_type, difficulty))
363
+ return out
364
+
365
  def _infer_answer_from_question(self, question: str, graph: CanonicalGraph) -> str:
366
  entities = self._extract_entity_tokens(question)
367
  question_l = question.lower()
 
409
  tasks: list[TaskInstance] = []
410
  for idx, question_spec in enumerate(self.config.seeding.seeded_questions):
411
  answer = question_spec.answer or self._infer_answer_from_question(question_spec.question, graph)
412
+ metadata = self._task_metadata(idx, question_spec.task_type, dict(question_spec.metadata))
413
+ difficulty = str(metadata.get("difficulty", "hard"))
 
 
 
414
  if question_spec.supporting_edges:
415
  support = [
416
  Edge(src=e.src, rel=e.rel, dst=e.dst, confidence=float(e.confidence))
 
463
  question=q,
464
  answer=a,
465
  supporting_edges=support,
466
+ metadata=self._task_metadata(start_idx + i, mode),
467
  )
468
  )
469
  return tasks
 
540
  question=question,
541
  answer=answer,
542
  supporting_edges=support,
543
+ metadata=self._task_metadata(
544
+ start_idx + len(llm_tasks),
545
+ task_type,
546
+ {"generated_by": "llm", "shared_context": True},
547
+ ),
548
  )
549
  )
550
  if len(llm_tasks) >= count:
 
589
  question=question,
590
  answer=answer,
591
  supporting_edges=support,
592
+ metadata=self._task_metadata(
593
+ start_idx + len(llm_tasks),
594
+ task_type,
595
+ {"generated_by": "llm", "shared_context": True},
596
+ ),
597
  )
598
  )
599
  if len(llm_tasks) >= count:
src/osint_env/env/environment.py CHANGED
@@ -216,11 +216,31 @@ class OSINTEnvironment(Env):
216
  def _observation(self) -> Observation:
217
  if self.state is None:
218
  raise RuntimeError("State is not initialized.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  return Observation(
220
  tool_outputs=self.state.tool_outputs[-5:],
221
  graph_snapshot=self.memory_graph.to_snapshot(),
222
  action_history=self.state.action_history[-10:],
223
- task={"task_id": self.state.task.task_id, "task_type": self.state.task.task_type, "question": self.state.task.question},
224
  )
225
 
226
  def _info(self) -> dict[str, Any]:
 
216
  def _observation(self) -> Observation:
217
  if self.state is None:
218
  raise RuntimeError("State is not initialized.")
219
+ metadata = dict(self.state.task.metadata or {})
220
+ grader = metadata.get("grader") if isinstance(metadata.get("grader"), dict) else None
221
+ task_payload = {
222
+ "task_id": self.state.task.task_id,
223
+ "task_type": self.state.task.task_type,
224
+ "question": self.state.task.question,
225
+ "difficulty": self.state.difficulty,
226
+ "grader": (
227
+ dict(grader)
228
+ if grader is not None
229
+ else {
230
+ "type": "difficulty_exact_match",
231
+ "answer_type": "node_id",
232
+ "case_sensitive": True,
233
+ "reward_profile": self.state.difficulty,
234
+ }
235
+ ),
236
+ }
237
+ if "scenario" in metadata:
238
+ task_payload["scenario"] = str(metadata.get("scenario", ""))
239
  return Observation(
240
  tool_outputs=self.state.tool_outputs[-5:],
241
  graph_snapshot=self.memory_graph.to_snapshot(),
242
  action_history=self.state.action_history[-10:],
243
+ task=task_payload,
244
  )
245
 
246
  def _info(self) -> dict[str, Any]:
tests/test_environment.py CHANGED
@@ -6,6 +6,8 @@ def test_episode_flow():
6
  env = OSINTEnvironment(EnvironmentConfig(max_steps=5, seed=5))
7
  obs = env.reset()
8
  assert "question" in obs.task
 
 
9
  obs, r1, done, _ = env.step(Action(ActionType.CALL_TOOL, {"tool_name": "search_posts", "args": {"query": "Update"}}))
10
  assert done is False
11
  assert isinstance(r1, float)
 
6
  env = OSINTEnvironment(EnvironmentConfig(max_steps=5, seed=5))
7
  obs = env.reset()
8
  assert "question" in obs.task
9
+ assert isinstance(obs.task.get("grader"), dict)
10
+ assert "type" in obs.task["grader"]
11
  obs, r1, done, _ = env.step(Action(ActionType.CALL_TOOL, {"tool_name": "search_posts", "args": {"query": "Update"}}))
12
  assert done is False
13
  assert isinstance(r1, float)