| |
| """Data loading utilities for ARC-AGI tasks.""" |
|
|
| import json |
| import os |
| import numpy as np |
| from .constants import CH, GH, GW |
|
|
|
|
def load_tasks_dir(data_dir, arcgen_dir=None):
    """Load every ``*.json`` task file in ``data_dir``, keyed by 1-based index.

    Files are processed in sorted filename order, so task numbers are stable
    for a given directory listing.  When ``arcgen_dir`` contains a file with
    the same name whose top-level JSON value is a list, that list replaces the
    task's ``'arc-gen'`` entry.  Every returned task has an ``'arc-gen'`` key
    (an empty list when no ARC-GEN examples were found).

    Args:
        data_dir: Directory containing the task JSON files.
        arcgen_dir: Optional directory of ARC-GEN example files, matched to
            tasks by filename.

    Returns:
        Dict mapping ``1..N`` to ``{'hex': <filename minus '.json'>,
        'data': <parsed task JSON>}``.
    """
    suffix = '.json'
    files = sorted(name for name in os.listdir(data_dir) if name.endswith(suffix))
    tasks = {}
    for idx, name in enumerate(files, start=1):
        # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
        with open(os.path.join(data_dir, name), encoding='utf-8') as fh:
            data = json.load(fh)
        # Strip only the trailing extension; str.replace would also remove
        # any '.json' that appears earlier in the filename.
        hex_id = name[:-len(suffix)]
        if arcgen_dir:
            arcgen_path = os.path.join(arcgen_dir, name)
            if os.path.exists(arcgen_path):
                with open(arcgen_path, encoding='utf-8') as fh:
                    arcgen_examples = json.load(fh)
                # Only a bare list of examples is accepted; other shapes are ignored.
                if isinstance(arcgen_examples, list):
                    data['arc-gen'] = arcgen_examples
        if 'arc-gen' not in data:
            data['arc-gen'] = []
        tasks[idx] = {'hex': hex_id, 'data': data}
    return tasks
|
|
|
|
def load_tasks_kaggle(data_dir, num_tasks=400):
    """Load Kaggle-style ``taskNNN.json`` files from ``data_dir``.

    Looks for ``task001.json`` up to ``task{num_tasks:03d}.json``.  Missing
    numbers are skipped, so the result may have gaps.  Every returned task has
    an ``'arc-gen'`` key (an empty list if the file does not provide one).

    Args:
        data_dir: Directory containing the task files.
        num_tasks: Highest task number to look for.  Defaults to 400, the
            previously hard-coded value.

    Returns:
        Dict mapping task number to ``{'hex': 'taskNNN', 'data': <parsed JSON>}``.
    """
    tasks = {}
    for tn in range(1, num_tasks + 1):
        name = f"task{tn:03d}"
        path = os.path.join(data_dir, f"{name}.json")
        if not os.path.exists(path):
            continue
        # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
        with open(path, encoding='utf-8') as fh:
            data = json.load(fh)
        if 'arc-gen' not in data:
            data['arc-gen'] = []
        tasks[tn] = {'hex': name, 'data': data}
    return tasks
|
|
|
|
def to_onehot(grid):
    """Encode a colour grid as a ``(1, CH, GH, GW)`` float32 one-hot array.

    A cell at ``(r, c)`` holding value ``v`` sets ``[0, v, r, c]`` to 1.0.
    Cells outside the ``GH x GW`` canvas, and values outside ``[0, CH)``, are
    skipped, leaving every channel at that position zero.
    """
    encoded = np.zeros((1, CH, GH, GW), dtype=np.float32)
    for row_idx, row in enumerate(grid):
        row_fits = row_idx < GH
        for col_idx, colour in enumerate(row):
            if not (row_fits and col_idx < GW):
                continue
            if 0 <= colour < CH:
                encoded[0, colour, row_idx, col_idx] = 1.0
    return encoded
|
|
|
|
def get_exs(td):
    """Return each train and test pair of ``td`` as int64 numpy arrays.

    Train examples come first, then test examples, each as an
    ``(input, output)`` tuple.
    """
    combined = td['train'] + td['test']
    pairs = []
    for example in combined:
        src = np.array(example['input'], dtype=np.int64)
        dst = np.array(example['output'], dtype=np.int64)
        pairs.append((src, dst))
    return pairs
|
|
|
|
def get_exs_for_fitting(td, max_arcgen=10):
    """Return base examples plus shape-compatible ARC-GEN examples for fitting.

    The base examples are every ``train`` and ``test`` pair of ``td``,
    converted to int64 arrays.  ARC-GEN examples (``td['arc-gen']``, if
    present) are appended only when all base inputs share one shape.  An
    ARC-GEN example is kept if its input has that shape and its output has
    the shape of the first base output.

    Args:
        td: Task dict with ``train``/``test`` lists of ``{'input', 'output'}``
            grids and an optional ``arc-gen`` list of the same form.
        max_arcgen: Maximum number of ARC-GEN examples to append
            (non-negative).  Defaults to 10, the previously hard-coded limit.

    Returns:
        List of ``(input, output)`` int64 array pairs: base examples first,
        then at most ``max_arcgen`` matching ARC-GEN examples in order.
    """
    base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
                for ex in td['train'] + td['test']]
    if not base_exs:
        return base_exs
    base_in_shapes = {inp.shape for inp, _ in base_exs}
    if len(base_in_shapes) != 1:
        # Inputs vary in shape: there is no single shape to filter ARC-GEN by.
        return base_exs
    (in_shape,) = base_in_shapes
    # NOTE(review): only input shapes are checked for consistency; ARC-GEN
    # outputs are matched against the first base output even if base outputs
    # differ from each other.  Confirm this is intended.
    out_shape = base_exs[0][1].shape
    ag_exs = []
    for ex in td.get('arc-gen', []):
        # Stop converting once the quota is filled; ARC-GEN lists can be long.
        if len(ag_exs) >= max_arcgen:
            break
        inp = np.array(ex['input'], dtype=np.int64)
        out = np.array(ex['output'], dtype=np.int64)
        if inp.shape == in_shape and out.shape == out_shape:
            ag_exs.append((inp, out))
    return base_exs + ag_exs
|
|
|
|
def get_exs_for_fitting_variable(td, max_arcgen=20, max_size=30):
    """Return base examples plus same-shape ARC-GEN examples for variable-shape fitting.

    The base examples are every ``train`` and ``test`` pair of ``td``,
    converted to int64 arrays, with no shape filtering.  An ARC-GEN example
    (from ``td['arc-gen']``, if present) is kept only if it is a 2-D grid,
    its input and output shapes are equal, and both dimensions are at most
    ``max_size``.

    Args:
        td: Task dict with ``train``/``test`` lists of ``{'input', 'output'}``
            grids and an optional ``arc-gen`` list of the same form.
        max_arcgen: Maximum number of ARC-GEN examples to append
            (non-negative).  Defaults to 20, the previously hard-coded limit.
        max_size: Largest allowed ARC-GEN grid height/width.  Defaults to 30,
            the previously hard-coded limit.

    Returns:
        List of ``(input, output)`` int64 array pairs: base examples first,
        then at most ``max_arcgen`` matching ARC-GEN examples in order.
    """
    base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
                for ex in td['train'] + td['test']]
    ag_exs = []
    for ex in td.get('arc-gen', []):
        # Stop converting once the quota is filled; ARC-GEN lists can be long.
        if len(ag_exs) >= max_arcgen:
            break
        inp = np.array(ex['input'], dtype=np.int64)
        out = np.array(ex['output'], dtype=np.int64)
        # Require a genuine 2-D grid: an empty grid ``[]`` yields a 1-D array,
        # and reading ``shape[1]`` on it would raise IndexError.
        if (inp.ndim == 2 and inp.shape == out.shape
                and inp.shape[0] <= max_size and inp.shape[1] <= max_size):
            ag_exs.append((inp, out))
    return base_exs + ag_exs
|
|
|
|
def fixed_shapes(td):
    """Return the ``(input_shape, output_shape)`` pair shared by every example.

    All train and test examples of ``td`` are considered.  If they all have
    the same input shape and the same output shape, that pair of shape tuples
    is returned; otherwise (including when there are no examples) ``None``.
    """
    seen = {(src.shape, dst.shape) for src, dst in get_exs(td)}
    if len(seen) != 1:
        return None
    (only,) = seen
    return only
|
|