File size: 3,354 Bytes
1e5b424
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#!/usr/bin/env python3
"""Data loading utilities for ARC-AGI tasks."""

import json
import os
import numpy as np
from .constants import CH, GH, GW


def load_tasks_dir(data_dir, arcgen_dir=None):
    """Load every ``*.json`` task in *data_dir*, numbered by sorted filename.

    Args:
        data_dir: Directory holding one JSON file per task.
        arcgen_dir: Optional directory of extra examples. When it contains a
            file with the same name as a task file and that file's top-level
            JSON value is a list, the list is stored under the task's
            ``'arc-gen'`` key, replacing any value already there. A
            non-list payload is ignored.

    Returns:
        Dict mapping a 1-based index (position in the sorted file listing) to
        ``{'hex': <filename without its .json suffix>, 'data': <parsed JSON>}``.
        An ``'arc-gen'`` key holding an empty list is added to ``data`` when
        the task has none.
    """
    suffix = '.json'
    filenames = sorted(
        name for name in os.listdir(data_dir) if name.endswith(suffix)
    )
    tasks = {}
    for index, filename in enumerate(filenames, start=1):
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(os.path.join(data_dir, filename), encoding='utf-8') as fh:
            data = json.load(fh)
        if arcgen_dir:
            arcgen_path = os.path.join(arcgen_dir, filename)
            if os.path.exists(arcgen_path):
                with open(arcgen_path, encoding='utf-8') as fh:
                    arcgen_examples = json.load(fh)
                if isinstance(arcgen_examples, list):
                    data['arc-gen'] = arcgen_examples
        if 'arc-gen' not in data:
            data['arc-gen'] = []
        # Strip only the trailing suffix: str.replace would also remove any
        # '.json' appearing earlier in the name (e.g. 'a.json.json' -> 'a').
        hex_id = filename[:-len(suffix)]
        tasks[index] = {'hex': hex_id, 'data': data}
    return tasks


def load_tasks_kaggle(data_dir, num_tasks=400):
    """Load tasks stored as ``taskNNN.json`` files (Kaggle layout).

    Args:
        data_dir: Directory containing files named ``task001.json``,
            ``task002.json``, ... (task number zero-padded to three digits).
        num_tasks: Highest task number to look for. Numbers ``1..num_tasks``
            are probed; the default of 400 matches the previous hard-coded
            range.

    Returns:
        Dict mapping each task number whose file exists to
        ``{'hex': 'taskNNN', 'data': <parsed JSON>}``. Missing files are
        skipped. An ``'arc-gen'`` key holding an empty list is added to
        ``data`` when the task has none.
    """
    tasks = {}
    for task_number in range(1, num_tasks + 1):
        task_name = f'task{task_number:03d}'
        path = os.path.join(data_dir, f'{task_name}.json')
        if not os.path.exists(path):
            continue
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(path, encoding='utf-8') as fh:
            data = json.load(fh)
        if 'arc-gen' not in data:
            data['arc-gen'] = []
        tasks[task_number] = {'hex': task_name, 'data': data}
    return tasks


def to_onehot(grid):
    """Encode a grid of values as a ``(1, CH, GH, GW)`` float32 one-hot array.

    For a cell at row ``r``, column ``c`` holding value ``v``, the entry
    ``[0, v, r, c]`` is set to 1.0. Cells at ``r >= GH`` or ``c >= GW`` are
    ignored, and a value outside ``[0, CH)`` leaves its cell zero in every
    channel.
    """
    onehot = np.zeros(shape=(1, CH, GH, GW), dtype=np.float32)
    for row_idx, cells in enumerate(grid):
        for col_idx, value in enumerate(cells):
            # Skip anything that would fall outside the fixed-size canvas.
            if not (row_idx < GH and col_idx < GW):
                continue
            # Skip values that have no channel of their own.
            if not (0 <= value < CH):
                continue
            onehot[0, value, row_idx, col_idx] = 1.0
    return onehot


def get_exs(td):
    """Return the train and test examples as ``(input, output)`` array pairs.

    ``td['train']`` and ``td['test']`` are concatenated in that order, and
    each example's ``'input'`` and ``'output'`` grids are converted to int64
    NumPy arrays.
    """
    pairs = []
    for example in td['train'] + td['test']:
        grid_in = np.array(example['input'], dtype=np.int64)
        grid_out = np.array(example['output'], dtype=np.int64)
        pairs.append((grid_in, grid_out))
    return pairs


def get_exs_for_fitting(td):
    """Return train/test examples, padded with up to 10 matching ARC-GEN ones.

    All examples are ``(input, output)`` pairs of int64 arrays. ARC-GEN
    examples (``td['arc-gen']``, optional) are only considered when every
    train/test input has the same shape; an ARC-GEN example is kept when its
    input has that shape and its output has the shape of the first
    train/test output. If there are no train/test examples, or their input
    shapes differ, only the train/test examples are returned.
    """
    def as_pair(example):
        return (np.array(example['input'], dtype=np.int64),
                np.array(example['output'], dtype=np.int64))

    originals = [as_pair(example) for example in td['train'] + td['test']]
    input_shapes = {grid_in.shape for grid_in, _ in originals}
    if len(input_shapes) != 1:
        # Either no examples at all or inputs of mixed shape: no augmentation.
        return originals

    (wanted_in,) = input_shapes
    wanted_out = originals[0][1].shape
    extras = [
        pair for pair in map(as_pair, td.get('arc-gen', []))
        if pair[0].shape == wanted_in and pair[1].shape == wanted_out
    ]
    return originals + extras[:10]


def get_exs_for_fitting_variable(td):
    """Return train/test examples plus up to 20 size-preserving ARC-GEN ones.

    All examples are ``(input, output)`` pairs of int64 arrays. Every
    train/test example is returned unconditionally. An ARC-GEN example
    (``td['arc-gen']``, optional) is kept only when its input is a 2-D array
    of the same shape as its output and neither side exceeds 30 cells.
    Examples that are not 2-D (for instance an empty grid ``[]``) are
    skipped rather than raising ``IndexError`` on the shape lookup.
    """
    base_exs = [
        (np.array(ex['input'], dtype=np.int64),
         np.array(ex['output'], dtype=np.int64))
        for ex in td['train'] + td['test']
    ]
    ag_exs = []
    for ex in td.get('arc-gen', []):
        inp = np.array(ex['input'], dtype=np.int64)
        out = np.array(ex['output'], dtype=np.int64)
        # shape[0] / shape[1] below are only defined for 2-D grids.
        if inp.ndim != 2 or inp.shape != out.shape:
            continue
        # NOTE(review): 30 is presumably the maximum grid side used
        # elsewhere in the project (GH/GW) - confirm before sharing it.
        if inp.shape[0] <= 30 and inp.shape[1] <= 30:
            ag_exs.append((inp, out))
    return base_exs + ag_exs[:20]


def fixed_shapes(td):
    """Return the one ``(input_shape, output_shape)`` pair shared by all examples.

    The examples come from ``get_exs(td)``. ``None`` is returned when they do
    not all share the same pair of shapes, or when there are no examples.
    """
    shape_pairs = {
        (grid_in.shape, grid_out.shape) for grid_in, grid_out in get_exs(td)
    }
    if len(shape_pairs) != 1:
        return None
    (only_pair,) = shape_pairs
    return only_pair