KnightBlade commited on
Commit
aa24d1e
·
1 Parent(s): c78c2fe

feat: Add Hugging Face dataset dynamic loading to environment reset

Browse files
server/data_wrangler_environment.py CHANGED
@@ -26,6 +26,39 @@ class DataWranglerEnvironment(Environment):
26
  def _initialize_task(self):
27
  self.df = pd.DataFrame()
28
  self.target_df = pd.DataFrame()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  if self.task_level == 1:
30
  # Easy: Just drop a column and rename one
31
  self.df = pd.DataFrame({
 
26
  def _initialize_task(self):
27
  self.df = pd.DataFrame()
28
  self.target_df = pd.DataFrame()
29
+
30
+ # Priority 5 - Dynamic Hugging Face or CSV Datasets
31
+ # If the user defines an external dataset via env var, load that instead.
32
+ dataset_source = os.environ.get("DATASET_SOURCE")
33
+ target_source = os.environ.get("TARGET_SOURCE")
34
+
35
+ if dataset_source:
36
+ if str(dataset_source).endswith(".csv"):
37
+ self.df = pd.read_csv(dataset_source)
38
+ elif str(dataset_source).endswith(".parquet"):
39
+ self.df = pd.read_parquet(dataset_source)
40
+ else:
41
+ from datasets import load_dataset
42
+ # Fallback to Hugging Face Hub (e.g. "scikit-learn/titanic", "argilla/news-summary")
43
+ # We grab the 'train' split by default and convert it to pandas
44
+ hf_data = load_dataset(dataset_source, split="train")
45
+ self.df = hf_data.to_pandas()
46
+
47
+ if target_source:
48
+ if str(target_source).endswith(".csv"):
49
+ self.target_df = pd.read_csv(target_source)
50
+ elif str(target_source).endswith(".parquet"):
51
+ self.target_df = pd.read_parquet(target_source)
52
+ else:
53
+ from datasets import load_dataset
54
+ hf_target = load_dataset(target_source, split="train")
55
+ self.target_df = hf_target.to_pandas()
56
+ else:
57
+ # If there's no target provided, we force the LLM to simply drop all rows with missing values
58
+ # as a baseline goal for Dynamic runs to prevent graded failures on unstructured tests
59
+ self.target_df = self.df.dropna()
60
+ return
61
+
62
  if self.task_level == 1:
63
  # Easy: Just drop a column and rename one
64
  self.df = pd.DataFrame({
server/requirements.txt CHANGED
@@ -6,3 +6,5 @@ uvicorn>=0.24.0
6
 
7
  pandas>=2.0.0
8
  openai>=1.0.0
 
 
 
6
 
7
  pandas>=2.0.0
8
  openai>=1.0.0
9
+ datasets>=2.14.0
10
+ pyarrow>=13.0.0