rogermt commited on
Commit
1e5b424
·
verified ·
1 Parent(s): ebfc1c9

Upload own-solver/neurogolf_solver/data_loader.py

Browse files
own-solver/neurogolf_solver/data_loader.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Data loading utilities for ARC-AGI tasks."""
3
+
4
+ import json
5
+ import os
6
+ import numpy as np
7
+ from .constants import CH, GH, GW
8
+
9
+
10
+ def load_tasks_dir(data_dir, arcgen_dir=None):
11
+ """Load tasks from directory structure."""
12
+ files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))
13
+ tasks = {}
14
+ for i, f in enumerate(files):
15
+ with open(os.path.join(data_dir, f)) as fh:
16
+ data = json.load(fh)
17
+ hex_id = f.replace('.json', '')
18
+ if arcgen_dir and os.path.exists(os.path.join(arcgen_dir, f)):
19
+ with open(os.path.join(arcgen_dir, f)) as fh:
20
+ arcgen_examples = json.load(fh)
21
+ if isinstance(arcgen_examples, list):
22
+ data['arc-gen'] = arcgen_examples
23
+ if 'arc-gen' not in data:
24
+ data['arc-gen'] = []
25
+ tasks[i + 1] = {'hex': hex_id, 'data': data}
26
+ return tasks
27
+
28
+
29
+ def load_tasks_kaggle(data_dir):
30
+ """Load tasks from Kaggle format."""
31
+ tasks = {}
32
+ for tn in range(1, 401):
33
+ path = os.path.join(data_dir, f"task{tn:03d}.json")
34
+ if os.path.exists(path):
35
+ with open(path) as f:
36
+ data = json.load(f)
37
+ if 'arc-gen' not in data:
38
+ data['arc-gen'] = []
39
+ tasks[tn] = {'hex': f'task{tn:03d}', 'data': data}
40
+ return tasks
41
+
42
+
43
+ def to_onehot(grid):
44
+ """Convert grid to one-hot encoding."""
45
+ arr = np.zeros((1, CH, GH, GW), dtype=np.float32)
46
+ for r, row in enumerate(grid):
47
+ for c, v in enumerate(row):
48
+ if r < GH and c < GW and 0 <= v < CH:
49
+ arr[0, v, r, c] = 1.0
50
+ return arr
51
+
52
+
53
+ def get_exs(td):
54
+ """Get examples as numpy arrays."""
55
+ return [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
56
+ for ex in td['train'] + td['test']]
57
+
58
+
59
+ def get_exs_for_fitting(td):
60
+ """Get examples for fitting with ARC-GEN augmentation."""
61
+ base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
62
+ for ex in td['train'] + td['test']]
63
+ if not base_exs:
64
+ return base_exs
65
+ base_shapes = {inp.shape for inp, _ in base_exs}
66
+ if len(base_shapes) != 1:
67
+ return base_exs
68
+ base_shape = list(base_shapes)[0]
69
+ ag_exs = []
70
+ for ex in td.get('arc-gen', []):
71
+ inp = np.array(ex['input'], dtype=np.int64)
72
+ out = np.array(ex['output'], dtype=np.int64)
73
+ if inp.shape == base_shape and out.shape == base_exs[0][1].shape:
74
+ ag_exs.append((inp, out))
75
+ return base_exs + ag_exs[:10]
76
+
77
+
78
+ def get_exs_for_fitting_variable(td):
79
+ """Get examples for variable-shape fitting."""
80
+ base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
81
+ for ex in td['train'] + td['test']]
82
+ ag_exs = []
83
+ for ex in td.get('arc-gen', []):
84
+ inp = np.array(ex['input'], dtype=np.int64)
85
+ out = np.array(ex['output'], dtype=np.int64)
86
+ if inp.shape == out.shape and inp.shape[0] <= 30 and inp.shape[1] <= 30:
87
+ ag_exs.append((inp, out))
88
+ return base_exs + ag_exs[:20]
89
+
90
+
91
+ def fixed_shapes(td):
92
+ """Check if task has fixed input/output shapes."""
93
+ shapes = set()
94
+ for inp, out in get_exs(td):
95
+ shapes.add((inp.shape, out.shape))
96
+ return list(shapes)[0] if len(shapes) == 1 else None