rogermt's picture
Upload own-solver/neurogolf_solver/data_loader.py
1e5b424 verified
#!/usr/bin/env python3
"""Data loading utilities for ARC-AGI tasks."""
import json
import os
import numpy as np
from .constants import CH, GH, GW
def load_tasks_dir(data_dir, arcgen_dir=None):
    """Load every ``*.json`` task in ``data_dir``, keyed by sorted position.

    Args:
        data_dir: Directory containing one JSON file per task.
        arcgen_dir: Optional directory with extra examples stored in files
            named exactly like the task files. A file is used only when its
            top-level JSON value is a list; it then replaces the task's
            ``'arc-gen'`` entry.

    Returns:
        Dict mapping the task number (1, 2, ... in sorted filename order) to
        ``{'hex': <filename without its .json extension>, 'data': <parsed JSON>}``.
        ``data['arc-gen']`` is always present (``[]`` if nothing supplied it).
    """
    suffix = '.json'
    names = sorted(n for n in os.listdir(data_dir) if n.endswith(suffix))
    tasks = {}
    for number, name in enumerate(names, start=1):
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(os.path.join(data_dir, name), encoding='utf-8') as fh:
            data = json.load(fh)

        if arcgen_dir:
            try:
                with open(os.path.join(arcgen_dir, name), encoding='utf-8') as fh:
                    extra = json.load(fh)
            except FileNotFoundError:
                # No companion file for this task: nothing to merge.
                extra = None
            if isinstance(extra, list):
                data['arc-gen'] = extra

        if 'arc-gen' not in data:
            data['arc-gen'] = []

        # Strip only the trailing extension; str.replace would also remove
        # any '.json' occurring in the middle of the filename.
        tasks[number] = {'hex': name[:-len(suffix)], 'data': data}
    return tasks
def load_tasks_kaggle(data_dir, num_tasks=400):
    """Load tasks stored as ``task001.json``, ``task002.json``, ... files.

    Args:
        data_dir: Directory containing the ``taskNNN.json`` files.
        num_tasks: Highest task number to look for; numbers ``1..num_tasks``
            are probed. Defaults to 400.

    Returns:
        Dict mapping each task number whose file exists to
        ``{'hex': 'taskNNN', 'data': <parsed JSON>}``. Missing files are
        skipped. ``data['arc-gen']`` is always present (``[]`` if absent in
        the file).
    """
    tasks = {}
    for number in range(1, num_tasks + 1):
        task_name = f"task{number:03d}"
        path = os.path.join(data_dir, f"{task_name}.json")
        try:
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(path, encoding='utf-8') as fh:
                data = json.load(fh)
        except FileNotFoundError:
            # Gaps in the numbering are expected; just move on.
            continue
        if 'arc-gen' not in data:
            data['arc-gen'] = []
        tasks[number] = {'hex': task_name, 'data': data}
    return tasks
def to_onehot(grid):
    """Encode a grid of integer cell values as a one-hot float32 array.

    The result has shape ``(1, CH, GH, GW)``. For a cell at row ``r``,
    column ``c`` holding value ``v``, element ``[0, v, r, c]`` is set to 1.0.
    Cells outside the ``GH`` x ``GW`` window, or whose value is not in
    ``[0, CH)``, are left as zeros.
    """
    onehot = np.zeros((1, CH, GH, GW), dtype=np.float32)
    for row_idx, cells in enumerate(grid):
        for col_idx, value in enumerate(cells):
            # Position check comes first so the value is only inspected
            # for cells that actually fit inside the array.
            fits = row_idx < GH and col_idx < GW
            if fits and 0 <= value < CH:
                onehot[0, value, row_idx, col_idx] = 1.0
    return onehot
def get_exs(td):
    """Return all train and test examples as ``(input, output)`` int64 arrays.

    Train examples come first, followed by test examples, in file order.
    """
    pairs = []
    for example in td['train'] + td['test']:
        grid_in = np.array(example['input'], dtype=np.int64)
        grid_out = np.array(example['output'], dtype=np.int64)
        pairs.append((grid_in, grid_out))
    return pairs
def get_exs_for_fitting(td, max_arcgen=10):
    """Return train/test examples plus shape-compatible ``'arc-gen'`` examples.

    Augmentation only happens when every train/test input has the same
    shape. An arc-gen example is kept when its input shape equals that common
    input shape and its output shape equals the output shape of the first
    train/test example.

    Args:
        td: Task dict with ``'train'`` and ``'test'`` lists and an optional
            ``'arc-gen'`` list; each example has ``'input'`` and ``'output'``.
        max_arcgen: Maximum number of arc-gen examples to append (default 10).

    Returns:
        List of ``(input, output)`` int64 array pairs: all train/test
        examples first, then up to ``max_arcgen`` matching arc-gen examples.
    """
    base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
                for ex in td['train'] + td['test']]
    if not base_exs:
        return base_exs

    # Inputs of differing shapes: there is no single shape to match against.
    if len({inp.shape for inp, _ in base_exs}) != 1:
        return base_exs

    want_in = base_exs[0][0].shape
    want_out = base_exs[0][1].shape
    ag_exs = []
    for ex in td.get('arc-gen', []):
        # Stop before converting examples that would be discarded anyway.
        if len(ag_exs) >= max_arcgen:
            break
        inp = np.array(ex['input'], dtype=np.int64)
        out = np.array(ex['output'], dtype=np.int64)
        if inp.shape == want_in and out.shape == want_out:
            ag_exs.append((inp, out))
    return base_exs + ag_exs
def get_exs_for_fitting_variable(td, max_arcgen=20, max_dim=30):
    """Return train/test examples plus size-preserving ``'arc-gen'`` examples.

    An arc-gen example is kept when its input and output have the same shape
    and its first two dimensions are each at most ``max_dim``. Examples with
    fewer than two dimensions (e.g. an empty grid) are skipped.

    Args:
        td: Task dict with ``'train'`` and ``'test'`` lists and an optional
            ``'arc-gen'`` list; each example has ``'input'`` and ``'output'``.
        max_arcgen: Maximum number of arc-gen examples to append (default 20).
        max_dim: Largest allowed height and width (default 30).

    Returns:
        List of ``(input, output)`` int64 array pairs: all train/test
        examples first, then up to ``max_arcgen`` accepted arc-gen examples.
    """
    base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
                for ex in td['train'] + td['test']]
    ag_exs = []
    for ex in td.get('arc-gen', []):
        # Stop before converting examples that would be discarded anyway.
        if len(ag_exs) >= max_arcgen:
            break
        inp = np.array(ex['input'], dtype=np.int64)
        out = np.array(ex['output'], dtype=np.int64)
        # Without a second axis there is no width to check; indexing
        # shape[1] on such an array would raise IndexError.
        if inp.ndim < 2 or inp.shape != out.shape:
            continue
        if inp.shape[0] <= max_dim and inp.shape[1] <= max_dim:
            ag_exs.append((inp, out))
    return base_exs + ag_exs
def fixed_shapes(td):
    """Return the one ``(input_shape, output_shape)`` pair shared by all examples.

    Returns ``None`` when the train/test examples do not all share the same
    pair of shapes (including when there are no examples at all).
    """
    distinct = {(inp.shape, out.shape) for inp, out in get_exs(td)}
    if len(distinct) != 1:
        return None
    (only_pair,) = distinct
    return only_pair