File size: 9,048 Bytes
8ed167c
b81d3dd
f65aaaf
a8989f9
f65aaaf
b81d3dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f65aaaf
f0eee3f
 
 
 
f65aaaf
f0eee3f
 
b81d3dd
 
f0eee3f
b81d3dd
f0eee3f
 
 
 
f65aaaf
f0eee3f
 
 
a8989f9
 
 
 
d5f5654
f0eee3f
 
 
 
 
 
 
 
71ddcd2
d5f5654
 
 
 
 
 
07c18c7
d5f5654
07c18c7
d5f5654
 
07c18c7
d5f5654
 
2060674
 
d5f5654
 
 
 
2060674
 
 
e76200d
 
 
 
07fd3b8
757c484
e76200d
2060674
d5f5654
 
07c18c7
2060674
d5f5654
 
07c18c7
2060674
6ec1943
b218e8e
6ec1943
e360100
 
 
6ec1943
 
 
d5f5654
 
 
 
 
 
 
 
a5b79af
d5f5654
 
 
 
 
b218e8e
 
 
 
6ec1943
 
 
 
07c18c7
6ec1943
d5f5654
6ec1943
 
 
d5f5654
 
6ec1943
 
 
 
 
d5f5654
a5b79af
d5f5654
 
07c18c7
a5b79af
07c18c7
f0eee3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
757c484
f0eee3f
 
 
07c18c7
 
 
 
 
 
 
 
 
 
6ec1943
 
 
 
 
 
 
07c18c7
6ec1943
 
 
 
 
 
07c18c7
 
 
 
 
 
 
 
 
6ec1943
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import mlcroissant._src.operation_graph.operations.download as dl_mod
import requests as _requests_mod
import requests
import os

_SERVER_HF_TOKEN = os.environ.get("HF_TOKEN")
print("[DEBUG] HF_TOKEN is", "set" if _SERVER_HF_TOKEN else "missing")

# _active_token holds the HF token to use for the current validation request.
# It defaults to the server-level HF_TOKEN but can be overridden per-request
# via set_active_token() so that logged-in users' own tokens are used instead.
_active_token: dict = {"token": _SERVER_HF_TOKEN}


def set_active_token(token: str | None) -> None:
    """Set the HF token to use for the current validation request."""
    _active_token["token"] = token if token else _SERVER_HF_TOKEN


def clear_active_token() -> None:
    """Reset the HF token back to the server-level default."""
    _active_token["token"] = _SERVER_HF_TOKEN


# Patch requests.Session.send to fail immediately on HTTP 429 instead of
# letting mlcroissant / fsspec / huggingface_hub retry silently for minutes.
_orig_session_send = _requests_mod.Session.send


def _format_retry_after(value):
    """Render a Retry-After header value for a human-readable error message.

    The header is either a delay in whole seconds or an HTTP-date; only the
    numeric form gets an "s" suffix. A missing or blank header is shown as
    "unknown" (previously this produced the garbled "unknowns").
    """
    value = (value or "").strip()
    if not value:
        return "unknown"
    return f"{value}s" if value.isdecimal() else value


def _rate_limit_aware_send(self, request, **kwargs):
    """Replacement for ``requests.Session.send`` that fails fast on HTTP 429.

    Sends *request* through the original ``Session.send``; any response other
    than 429 is returned unchanged.

    Raises:
        requests.exceptions.HTTPError: when the server answers 429; the
            response is attached so callers can inspect it.
    """
    response = _orig_session_send(self, request, **kwargs)
    if response.status_code != 429:
        return response
    retry_after = _format_retry_after(response.headers.get("Retry-After"))
    raise _requests_mod.exceptions.HTTPError(
        f"HTTP 429 Too Many Requests for {request.url}. "
        f"Retry-After: {retry_after}. "
        "You are being rate limited. Log in with your Hugging Face account to avoid this.",
        response=response,
    )


_requests_mod.Session.send = _rate_limit_aware_send


# Only send HF credentials when downloading from huggingface.co.
# The default get_basic_auth_from_env() applies auth to ALL URLs, which
# causes non-HF hosts (e.g. OpenML) to return 400 Bad Request.
_orig_download_from_http = dl_mod.Download._download_from_http


def _is_huggingface_url(url):
    """Return True when *url*'s host is huggingface.co or a subdomain of it.

    The check uses the parsed hostname rather than a substring search over the
    whole URL: a substring test also matches attacker-chosen hosts such as
    ``https://huggingface.co.example.net/`` or
    ``https://example.net/?ref=huggingface.co`` and would hand them the token.
    Anything that is not a parseable string URL is treated as non-HF.
    """
    from urllib.parse import urlsplit

    if not isinstance(url, str):
        return False
    try:
        host = urlsplit(url).hostname or ""
    except ValueError:  # e.g. a malformed IPv6 literal
        return False
    return host == "huggingface.co" or host.endswith(".huggingface.co")


def _hf_aware_download(self, filepath):
    """Run mlcroissant's HTTP download, sending HF credentials only to Hugging Face.

    Credentials are passed through the CROISSANT_BASIC_AUTH_USERNAME/PASSWORD
    environment variables: they are set when a token is active and the node's
    content_url is on huggingface.co, and removed otherwise so other hosts
    receive no auth.
    """
    url = self.node.content_url or ""
    token = _active_token["token"]
    if token and _is_huggingface_url(url):
        os.environ["CROISSANT_BASIC_AUTH_USERNAME"] = "hf_user"
        os.environ["CROISSANT_BASIC_AUTH_PASSWORD"] = token
    else:
        os.environ.pop("CROISSANT_BASIC_AUTH_USERNAME", None)
        os.environ.pop("CROISSANT_BASIC_AUTH_PASSWORD", None)
    # NOTE(review): os.environ is process-wide, so concurrent downloads could
    # observe each other's credentials — confirm downloads are serialized.
    return _orig_download_from_http(self, filepath)


dl_mod.Download._download_from_http = _hf_aware_download

import logging
import mlcroissant as mlc
import func_timeout
import json
import traceback

# Suppress noisy mlcroissant pattern-matching warnings ("Could not match ...").
def _drop_could_not_match(record):
    """Logging filter: reject records whose message contains "Could not match".

    Logger-level filters run outside Handler.handleError, so an exception here
    would propagate into the code that issued the log call. A record whose
    message cannot be formatted (e.g. mismatched %-args) is therefore let
    through instead of raising.
    """
    try:
        message = record.getMessage()
    except Exception:
        return True
    return "Could not match" not in message


# On Python >= 3.9 both calls return the root logger; adding the same filter
# object twice is a no-op, while older versions still get both loggers covered.
logging.getLogger("root").addFilter(_drop_could_not_match)
logging.getLogger().addFilter(_drop_could_not_match)

# Time limit, in seconds, passed to func_timeout for each record set in
# validate_records(): generating the first record must finish within it.
WAIT_TIME = 10 * 60  # seconds

def validate_json(file_path):
    """Check that *file_path* contains well-formed JSON.

    The file is decoded as UTF-8 (the encoding JSON text must use per
    RFC 8259), tolerating a leading byte-order mark, instead of the
    platform's locale-dependent default encoding.

    Returns:
        ``(True, "The file is valid JSON.", parsed_data)`` on success,
        otherwise ``(False, error_message, None)``. Syntax errors are reported
        as "Invalid JSON format: ..."; any other failure (missing file,
        bytes that are not valid UTF-8, ...) as "Error reading file: ...".
    """
    try:
        with open(file_path, "r", encoding="utf-8-sig") as f:
            json_data = json.load(f)
    except json.JSONDecodeError as e:
        return False, f"Invalid JSON format: {e}", None
    except Exception as e:  # report I/O and decoding failures instead of raising
        return False, f"Error reading file: {e}", None
    return True, "The file is valid JSON.", json_data

REQUIRED_SCHEMA_FIELDS = ["license"]

def validate_croissant(json_data):
    """Load *json_data* with mlcroissant and check the required top-level fields.

    Returns a ``(passed, message, status)`` triple:

    * the dataset loads and every REQUIRED_SCHEMA_FIELDS entry is non-empty
      -> ``(True, "The dataset passes Croissant validation.", "pass")``;
    * the dataset loads but a required field is absent/empty
      -> ``(True, <license guidance>, "error")``;
    * mlcroissant (or the field check) raises
      -> ``(False, <error text plus traceback>, "error")``.
    """
    try:
        mlc.Dataset(jsonld=json_data)  # constructing the dataset runs schema validation
        missing = [name for name in REQUIRED_SCHEMA_FIELDS if not json_data.get(name)]
    except (mlc.ValidationError, Exception) as exc:
        if isinstance(exc, mlc.ValidationError):
            prefix = "Validation failed"
        else:
            prefix = "Unexpected error during validation"
        return False, f"{prefix}: {exc}\n\n{traceback.format_exc()}", "error"

    if not missing:
        return True, "The dataset passes Croissant validation.", "pass"

    guidance = (
        "The `license` field is missing. This is required for NeurIPS dataset submissions. "
        "Please add a `license` field to your Croissant file with the name or URL of the licence governing your dataset. "
        "Where possible, use <a href='https://www.kaggle.com/discussions/getting-started/116476' target='_blank'>open licenses</a> that "
        "allow reuse and reproducibility. However, when the dataset contains sensitive data or stricter licensing is unavoidable, "
        "please select an appropriate license that is as open as possible given the constraints. "
        "You can use our <a href='https://huggingface.co/spaces/JoaquinVanschoren/croissant-rai-checker' target='_blank'>online RAI editing tool</a> to fill this information."
    )
    return True, guidance, "error"
    
def try_generate_record(record_collection):
    """Pull at most one item from *record_collection* to prove generation works.

    Only the first record is requested; later records are never produced.
    An empty collection also counts as success.

    Returns:
        ``"success"`` if the first record (or an empty collection) was produced
        without error, otherwise the exception instance that was raised. The
        exception is returned rather than raised so the caller can re-raise it
        itself.
    """
    try:
        next(iter(record_collection), None)
    except Exception as e:
        return e
    return "success"

def _format_wait_time(seconds):
    """Describe a timeout for user-facing messages.

    Whole minutes are rendered in minutes (600 -> "10 minutes"); any other
    value is rendered in seconds (90 -> "90 seconds").
    """
    if seconds % 60 == 0:
        minutes = seconds // 60
        return f"{minutes} minute" if minutes == 1 else f"{minutes} minutes"
    return f"{seconds} seconds"


def _generate_first_record(dataset, record_set_uuid):
    """Create the record iterator for one record set and try to pull a record.

    The iterator is built here, inside the timed call, so slow setup also
    counts against the time limit.
    """
    return try_generate_record(dataset.records(record_set=record_set_uuid))


def validate_records(json_data):
    """Check that every record set can yield its first record within WAIT_TIME.

    Returns a ``(passed, message, status)`` triple. Processing stops at the
    first record set that times out or fails to generate (status "warning");
    a failure to load the dataset itself is reported with status "error".
    """
    try:
        dataset = mlc.Dataset(jsonld=json_data)
        record_sets = dataset.metadata.record_sets

        if not record_sets:
            return True, "No record sets found to validate.", "pass"

        results = []

        for record_set in record_sets:
            try:
                # Arguments are passed explicitly rather than captured by a
                # lambda closing over the loop variable.
                result = func_timeout.func_timeout(
                    WAIT_TIME,
                    _generate_first_record,
                    args=(dataset, record_set.uuid),
                )

                if isinstance(result, Exception):
                    raise result  # re-raise actual error outside timeout

                results.append(f"Record set '{record_set.uuid}' passed validation.")

            except func_timeout.exceptions.FunctionTimedOut:
                # Derive the limit from WAIT_TIME instead of hard-coding it.
                error_message = (
                    f"Record set '{record_set.uuid}' generation took too long "
                    f"(>{_format_wait_time(WAIT_TIME)})."
                )
                return False, error_message, "warning"

            except Exception as e:
                error_details = traceback.format_exc()
                error_message = (
                    f"Record set '{record_set.uuid}' failed due to generation error:\n\n"
                    f"```text\n{str(e)}\n\n{error_details}```"
                )
                return False, error_message, "warning"

        return True, "\n".join(results), "pass"
    except Exception as e:
        error_details = traceback.format_exc()
        error_message = f"Unexpected error during records validation: {str(e)}\n\n{error_details}"
        return False, error_message, "error"
    
RAI_FIELDS = [
    "rai:dataLimitations",
    "rai:dataBiases",
    "rai:personalSensitiveInformation",
    "rai:dataUseCases",
    "rai:dataSocialImpact",
    "rai:hasSyntheticData",
]

RAI_GUIDELINES_URL = "https://neurips.cc/Conferences/2026/EvaluationsDatasetsHosting"

def validate_rai(json_data):
    """Report which required Responsible AI (``rai:*``) keys are absent.

    A field counts as present when its key exists in *json_data*, whatever its
    value. Returns ``(True, message)`` when nothing is missing, otherwise
    ``(False, message)`` listing each missing key as a Markdown bullet,
    followed by links to the NeurIPS guidelines and the RAI editing tool.
    """
    absent = [key for key in RAI_FIELDS if key not in json_data]
    if absent:
        bullets = "\n".join(f"- `{key}`" for key in absent)
        return False, (
            f"The following required Responsible AI metadata fields are missing:\n{bullets}\n\n"
            f"Please refer to the <a href='{RAI_GUIDELINES_URL}' target='_blank'>NeurIPS guidelines for instructions</a> on how to add them.\n\n You can use our <a href='https://huggingface.co/spaces/JoaquinVanschoren/croissant-rai-checker' target='_blank'>online RAI editing tool</a> to facilitate this process."
        )
    return True, "All required Responsible AI metadata fields are present."

def generate_validation_report(filename, json_data, results):
    """Build a Markdown report of validation outcomes followed by the JSON-LD.

    Args:
        filename: name shown in the report's "Starting validation" line.
        json_data: parsed document, dumped as 2-space-indented JSON at the end.
        results: iterable of ``(test_name, passed, message)`` or
            ``(test_name, passed, message, status)`` tuples. Without an
            explicit status, ``passed`` maps to "pass" or "error".

    Returns:
        The newline-joined report. Each result gets a "### name" heading, a
        mark ("✓" for "pass", "?" for "warning", "✗" otherwise) and the
        whitespace-stripped message.
    """
    lines = [
        "# CROISSANT VALIDATION REPORT",
        "=" * 80,
        "## VALIDATION RESULTS",
        "-" * 80,
        f"Starting validation for file: {filename}",
    ]

    for entry in results:
        if len(entry) != 4:
            test_name, passed, message = entry
            status = "pass" if passed else "error"
        else:
            test_name, passed, message, status = entry

        mark = "✓" if status == "pass" else ("?" if status == "warning" else "✗")
        lines.extend((f"### {test_name}", mark, message.strip()))

    lines.extend((
        "## JSON-LD REFERENCE",
        "=" * 80,
        "```json",
        json.dumps(json_data, indent=2),
        "```",
    ))
    return "\n".join(lines)