#!/usr/bin/env python3
"""Data loading utilities for ARC-AGI tasks."""

import json
import os

import numpy as np

from .constants import CH, GH, GW


def _read_json(path):
    """Parse and return the JSON document at *path*, decoded as UTF-8."""
    # Explicit encoding: the default for open() is platform-dependent.
    with open(path, encoding='utf-8') as fh:
        return json.load(fh)


def _example_to_arrays(ex):
    """Convert one example's ``'input'``/``'output'`` grids to int64 arrays."""
    return (np.array(ex['input'], dtype=np.int64),
            np.array(ex['output'], dtype=np.int64))


def load_tasks_dir(data_dir, arcgen_dir=None):
    """Load every ``*.json`` task file found in *data_dir*.

    Files are processed in sorted filename order and keyed by their 1-based
    position in that order.  When *arcgen_dir* is given and holds a file of
    the same name whose top-level JSON value is a list, that list is stored
    under the task's ``'arc-gen'`` key.  Every returned task has an
    ``'arc-gen'`` key (an empty list if none was found).

    Args:
        data_dir: directory containing the task JSON files.
        arcgen_dir: optional directory of per-task ARC-GEN example lists.

    Returns:
        dict mapping ``int`` index -> ``{'hex': <filename without .json>,
        'data': <parsed task JSON>}``.
    """
    suffix = '.json'
    files = sorted(f for f in os.listdir(data_dir) if f.endswith(suffix))
    tasks = {}
    for i, fname in enumerate(files, start=1):
        data = _read_json(os.path.join(data_dir, fname))
        # Strip only the trailing suffix; str.replace would also remove
        # '.json' occurring elsewhere in the filename.
        hex_id = fname[:-len(suffix)]
        if arcgen_dir:
            arcgen_path = os.path.join(arcgen_dir, fname)
            # isfile (not exists): a same-named directory cannot be opened.
            if os.path.isfile(arcgen_path):
                arcgen_examples = _read_json(arcgen_path)
                # Non-list payloads are ignored, keeping any existing data.
                if isinstance(arcgen_examples, list):
                    data['arc-gen'] = arcgen_examples
        data.setdefault('arc-gen', [])
        tasks[i] = {'hex': hex_id, 'data': data}
    return tasks


def load_tasks_kaggle(data_dir, num_tasks=400):
    """Load tasks stored in Kaggle format (``task001.json``, ``task002.json``, ...).

    Task numbers ``1..num_tasks`` are probed; missing files are skipped, so
    the result may have gaps.  Every returned task has an ``'arc-gen'`` key
    (an empty list if the file did not provide one).

    Args:
        data_dir: directory containing the ``taskNNN.json`` files.
        num_tasks: highest task number to probe (default 400).

    Returns:
        dict mapping task number -> ``{'hex': 'taskNNN', 'data': <task JSON>}``.
    """
    tasks = {}
    for tn in range(1, num_tasks + 1):
        name = f'task{tn:03d}'
        path = os.path.join(data_dir, f'{name}.json')
        if not os.path.isfile(path):
            continue
        data = _read_json(path)
        data.setdefault('arc-gen', [])
        tasks[tn] = {'hex': name, 'data': data}
    return tasks


def to_onehot(grid):
    """Return a ``(1, CH, GH, GW)`` float32 one-hot encoding of *grid*.

    ``arr[0, v, r, c]`` is 1.0 when cell ``(r, c)`` holds value ``v``.  Cells
    outside the ``GH x GW`` canvas and values outside ``[0, CH)`` are
    silently dropped rather than raising.
    """
    arr = np.zeros((1, CH, GH, GW), dtype=np.float32)
    for r, row in enumerate(grid):
        for c, v in enumerate(row):
            if r < GH and c < GW and 0 <= v < CH:
                arr[0, v, r, c] = 1.0
    return arr


def get_exs(td):
    """Return the train + test examples of task data *td* as int64 array pairs."""
    return [_example_to_arrays(ex) for ex in td['train'] + td['test']]


def get_exs_for_fitting(td, max_arcgen=10):
    """Return train + test examples, augmented with shape-compatible ARC-GEN ones.

    ARC-GEN examples are only added when every base example shares the same
    input shape.  An ARC-GEN example is kept if its input matches that shape
    and its output matches the first base example's output shape.

    Args:
        td: task data dict with ``'train'``, ``'test'`` and optionally
            ``'arc-gen'`` example lists.
        max_arcgen: maximum number of ARC-GEN examples appended (default 10).

    Returns:
        list of ``(input, output)`` int64 array pairs, base examples first.
    """
    base_exs = get_exs(td)
    if not base_exs:
        return base_exs
    base_shapes = {inp.shape for inp, _ in base_exs}
    if len(base_shapes) != 1:
        return base_exs
    (base_shape,) = base_shapes
    # NOTE(review): only input shapes are required to agree across the base
    # examples; outputs are compared against the first example's shape only.
    out_shape = base_exs[0][1].shape
    ag_exs = []
    for ex in td.get('arc-gen', []):
        inp, out = _example_to_arrays(ex)
        if inp.shape == base_shape and out.shape == out_shape:
            ag_exs.append((inp, out))
    return base_exs + ag_exs[:max_arcgen]


def get_exs_for_fitting_variable(td, max_arcgen=20):
    """Return train + test examples plus same-shape ARC-GEN examples.

    An ARC-GEN example is kept when it is 2-D, its input and output shapes
    are equal, and both sides are at most 30 cells.

    Args:
        td: task data dict with ``'train'``, ``'test'`` and optionally
            ``'arc-gen'`` example lists.
        max_arcgen: maximum number of ARC-GEN examples appended (default 20).

    Returns:
        list of ``(input, output)`` int64 array pairs, base examples first.
    """
    base_exs = get_exs(td)
    ag_exs = []
    for ex in td.get('arc-gen', []):
        inp, out = _example_to_arrays(ex)
        # Check ndim first: a degenerate grid such as [] gives a 1-D array,
        # on which shape[1] would raise IndexError.
        if (inp.ndim == 2 and inp.shape == out.shape
                and inp.shape[0] <= 30 and inp.shape[1] <= 30):
            ag_exs.append((inp, out))
    return base_exs + ag_exs[:max_arcgen]


def fixed_shapes(td):
    """Return ``(input_shape, output_shape)`` if all train + test examples share it.

    Returns None when examples disagree on shapes or there are no examples.
    """
    shapes = {(inp.shape, out.shape) for inp, out in get_exs(td)}
    if len(shapes) == 1:
        return next(iter(shapes))
    return None