File size: 4,177 Bytes
2312199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""Task loader service — reads task definitions from JSON files."""

from __future__ import annotations

import json
import logging
from pathlib import Path

from lexenvs.config import get_settings, project_root
from lexenvs.schemas.task import TaskDefinition
from lexenvs.services.kb_filter_service import filter_allowlist, filter_knowledge_base

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class TaskLoaderService:
    """Loads task definitions from the data/tasks/ directory.

    Every ``*.json`` file in the tasks directory is parsed, the file
    references in its prompt are resolved, and the result is validated into a
    ``TaskDefinition`` stored under its ``task_id``.

    Resolves ``knowledge_base_ref`` and ``system_prompt_ref`` in each task's
    prompt by reading the referenced files from the data directory (the
    parent of the tasks directory).
    """

    def __init__(self, tasks_dir: str | Path | None = None) -> None:
        """Create the service and eagerly load every task file.

        Args:
            tasks_dir: Directory holding the task JSON files. When ``None``,
                ``settings.tasks_data_dir`` is resolved against
                ``project_root()``.
        """
        if tasks_dir is not None:
            self._tasks_dir = Path(tasks_dir)
        else:
            settings = get_settings()
            self._tasks_dir = project_root() / settings.tasks_data_dir

        # Referenced files are looked up relative to the parent of the tasks dir.
        self._data_dir = self._tasks_dir.parent
        self._file_cache: dict[str, str] = {}
        self._tasks: dict[str, TaskDefinition] = {}
        self._load_tasks()

    def _read_data_file(self, ref: str) -> str:
        """Read and cache a file from the data directory.

        Args:
            ref: Path of the file, relative to the data directory.

        Returns:
            The file's text (decoded as UTF-8), or ``""`` when the file does
            not exist. Missing files are logged and not cached.

        Raises:
            ValueError: If ``ref`` is not a string, or if it resolves to a
                location outside the data directory.
        """
        # Refs come from task JSON, so they may be any JSON value. Reject
        # non-strings explicitly instead of failing with a TypeError on the
        # cache lookup or the path join below.
        if not isinstance(ref, str):
            raise ValueError(f"File ref must be a string, got {type(ref).__name__}")

        if ref in self._file_cache:
            return self._file_cache[ref]

        # Resolve first so that ".." segments and absolute refs are detected.
        file_path = (self._data_dir / ref).resolve()
        if not file_path.is_relative_to(self._data_dir.resolve()):
            logger.error("File ref escapes data directory: %s", ref)
            raise ValueError(f"Invalid file ref: {ref!r}")

        if not file_path.exists():
            logger.warning("Data file not found: %s", file_path)
            return ""

        content = file_path.read_text(encoding="utf-8")
        self._file_cache[ref] = content
        logger.info("Loaded data file: %s (%d chars)", ref, len(content))
        return content

    def _resolve_prompt_refs(self, prompt: dict) -> None:
        """Fill ``system`` and ``context`` in ``prompt`` from referenced files.

        ``prompt`` is modified in place. An existing non-empty ``system`` or
        ``context`` value is kept; the corresponding ref is then ignored.

        Raises:
            ValueError: Propagated from ``_read_data_file`` for invalid refs.
            OSError: Propagated from ``_read_data_file`` when a file cannot
                be read.
        """
        # Resolve system_prompt_ref → system
        sys_ref = prompt.get("system_prompt_ref")
        if sys_ref and not prompt.get("system"):
            prompt["system"] = self._read_data_file(sys_ref)

        # Resolve knowledge_base_ref → context
        kb_ref = prompt.get("knowledge_base_ref")
        if kb_ref and not prompt.get("context"):
            full_kb = self._read_data_file(kb_ref)
            kb_filter = prompt.get("kb_filter")
            if kb_filter:
                # With a filter, both the context and the system prompt are
                # passed through the filter helpers.
                prompt["context"] = filter_knowledge_base(full_kb, kb_filter)
                prompt["system"] = filter_allowlist(prompt.get("system", ""), kb_filter)
            else:
                prompt["context"] = full_kb

    def _load_tasks(self) -> None:
        """Load all task JSON files from the tasks directory.

        Files are processed in sorted path order. A file that cannot be read,
        parsed, or validated is logged and skipped so that one bad file does
        not prevent the remaining tasks from loading. When two files declare
        the same ``task_id``, the later one wins and a warning is logged.
        """
        if not self._tasks_dir.exists():
            logger.warning("Tasks directory not found: %s", self._tasks_dir)
            return

        for task_file in sorted(self._tasks_dir.glob("*.json")):
            try:
                raw = json.loads(task_file.read_text(encoding="utf-8"))
                # json.loads may return a list, string, number, etc.; only an
                # object can carry a prompt. Raise ValueError so this file is
                # reported and skipped by the handler below, rather than
                # failing with an uncaught AttributeError on ``raw.get``.
                if not isinstance(raw, dict):
                    raise ValueError(
                        f"expected a JSON object at top level, got {type(raw).__name__}"
                    )

                prompt = raw.get("prompt", {})
                # A prompt that is not an object has no refs to resolve; leave
                # it to TaskDefinition.model_validate to accept or reject.
                if isinstance(prompt, dict):
                    self._resolve_prompt_refs(prompt)

                task = TaskDefinition.model_validate(raw)
                if task.task_id in self._tasks:
                    logger.warning(
                        "Duplicate task_id %s in %s; overwriting earlier definition",
                        task.task_id,
                        task_file.name,
                    )
                self._tasks[task.task_id] = task
                logger.info("Loaded task: %s", task.task_id)
            except (json.JSONDecodeError, ValueError) as e:
                logger.error("Invalid task file %s: %s", task_file.name, e)
            except OSError as e:
                logger.error("Cannot read task file %s: %s", task_file.name, e)

        logger.info("Loaded %d tasks from %s", len(self._tasks), self._tasks_dir)

    def get_task(self, task_id: str) -> TaskDefinition | None:
        """Get a single task definition by ID, or ``None`` if it is not loaded."""
        return self._tasks.get(task_id)

    def list_tasks(self) -> list[TaskDefinition]:
        """Return all loaded task definitions as a new list."""
        return list(self._tasks.values())

    def reload(self) -> None:
        """Reload tasks from disk, discarding cached tasks and data files."""
        self._tasks.clear()
        self._file_cache.clear()
        self._load_tasks()


def create_task_loader_service() -> TaskLoaderService:
    """Build a default ``TaskLoaderService`` (factory for svcs registration).

    No tasks directory is passed, so the service constructor chooses its
    own default location.
    """
    service = TaskLoaderService()
    return service