Jayant-Kernel committed on
fix: download datasets from GitHub at runtime instead of relying on package data
Browse files
train.py
CHANGED
|
@@ -58,9 +58,21 @@ from deceit_env.server.grader import Grader
|
|
| 58 |
from deceit_env.models import DeceitAction
|
| 59 |
import deceit_env as _pkg
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
_grader = Grader(cache_path="/tmp/deceit_grader_cache.json",
|
| 62 |
openai_api_key=os.environ["OPENAI_API_KEY"])
|
| 63 |
-
_env = DeceitEnvironment(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
_env_lock = threading.Lock()
|
| 65 |
|
| 66 |
# Parser
|
|
@@ -115,10 +127,8 @@ def reward_fn(completions, prompts=None, **kwargs):
|
|
| 115 |
return rewards
|
| 116 |
|
| 117 |
# Dataset
|
| 118 |
-
import deceit_env as _de
|
| 119 |
-
data_path = _pathlib.Path(_de.__file__).parent / "data" / "level1.jsonl"
|
| 120 |
questions = []
|
| 121 |
-
with open(
|
| 122 |
for line in f:
|
| 123 |
line = line.strip()
|
| 124 |
if line:
|
|
|
|
| 58 |
from deceit_env.models import DeceitAction
|
| 59 |
import deceit_env as _pkg
|
| 60 |
|
| 61 |
+
# Download datasets from GitHub
|
| 62 |
+
import urllib.request as _ur
|
| 63 |
+
_RAW = "https://raw.githubusercontent.com/Jayant-kernel/DECEIT-the-ai-truth-environment-/main/src/deceit_env/data"
|
| 64 |
+
for _fname in ["level1.jsonl", "level2.jsonl", "level3.jsonl"]:
|
| 65 |
+
_ur.urlretrieve(f"{_RAW}/{_fname}", f"/tmp/{_fname}")
|
| 66 |
+
print("Datasets downloaded.")
|
| 67 |
+
|
| 68 |
_grader = Grader(cache_path="/tmp/deceit_grader_cache.json",
|
| 69 |
openai_api_key=os.environ["OPENAI_API_KEY"])
|
| 70 |
+
_env = DeceitEnvironment(
|
| 71 |
+
dataset_path="/tmp/level1.jsonl",
|
| 72 |
+
level2_dataset_path="/tmp/level2.jsonl",
|
| 73 |
+
level3_dataset_path="/tmp/level3.jsonl",
|
| 74 |
+
grader=_grader,
|
| 75 |
+
)
|
| 76 |
_env_lock = threading.Lock()
|
| 77 |
|
| 78 |
# Parser
|
|
|
|
| 127 |
return rewards
|
| 128 |
|
| 129 |
# Dataset
|
|
|
|
|
|
|
| 130 |
questions = []
|
| 131 |
+
with open("/tmp/level1.jsonl") as f:
|
| 132 |
for line in f:
|
| 133 |
line = line.strip()
|
| 134 |
if line:
|