# croissant-checker / validation.py
# Author: JoaquinVanschoren
# Last change (commit b81d3dd): Scalability improvement: HF login, dynamic
# test updates, improved warnings for downloading issues
import mlcroissant._src.operation_graph.operations.download as dl_mod
import requests as _requests_mod
import requests
import os
_SERVER_HF_TOKEN = os.environ.get("HF_TOKEN")
print("[DEBUG] HF_TOKEN is", "set" if _SERVER_HF_TOKEN else "missing")
# _active_token holds the HF token to use for the current validation request.
# It defaults to the server-level HF_TOKEN but can be overridden per-request
# via set_active_token() so that logged-in users' own tokens are used instead.
_active_token: dict = {"token": _SERVER_HF_TOKEN}
def set_active_token(token: str | None) -> None:
"""Set the HF token to use for the current validation request."""
_active_token["token"] = token if token else _SERVER_HF_TOKEN
def clear_active_token() -> None:
"""Reset the HF token back to the server-level default."""
_active_token["token"] = _SERVER_HF_TOKEN
# Patch requests.Session.send to fail immediately on HTTP 429 instead of
# letting mlcroissant / fsspec / huggingface_hub retry silently for minutes.
_orig_session_send = _requests_mod.Session.send


def _rate_limit_aware_send(self, request, **kwargs):
    """Send *request*, but raise an HTTPError immediately on a 429 response."""
    resp = _orig_session_send(self, request, **kwargs)
    if resp.status_code != 429:
        return resp
    # Surface the server's back-off hint, if any, in the error message.
    retry_after = resp.headers.get("Retry-After", "unknown")
    raise _requests_mod.exceptions.HTTPError(
        f"HTTP 429 Too Many Requests for {request.url}. "
        f"Retry-After: {retry_after}s. "
        "You are being rate limited. Log in with your Hugging Face account to avoid this.",
        response=resp,
    )


_requests_mod.Session.send = _rate_limit_aware_send
# Only send HF credentials when downloading from huggingface.co.
# The default get_basic_auth_from_env() applies auth to ALL URLs, which
# causes non-HF hosts (e.g. OpenML) to return 400 Bad Request.
_orig_download_from_http = dl_mod.Download._download_from_http


def _hf_aware_download(self, filepath):
    """Download wrapper that scopes HF basic-auth credentials to huggingface.co URLs."""
    target_url = self.node.content_url or ""
    hf_token = _active_token["token"]
    if hf_token and "huggingface.co" in target_url:
        # mlcroissant picks these env vars up via get_basic_auth_from_env().
        os.environ["CROISSANT_BASIC_AUTH_USERNAME"] = "hf_user"
        os.environ["CROISSANT_BASIC_AUTH_PASSWORD"] = hf_token
    else:
        # Non-HF host (or no token): make sure no stale credentials are sent.
        for var in ("CROISSANT_BASIC_AUTH_USERNAME", "CROISSANT_BASIC_AUTH_PASSWORD"):
            os.environ.pop(var, None)
    return _orig_download_from_http(self, filepath)


dl_mod.Download._download_from_http = _hf_aware_download
import logging
import mlcroissant as mlc
import func_timeout
import json
import traceback
# Suppress noisy mlcroissant pattern-matching warnings.
# NOTE(review): logging.getLogger("root") returns a *named* logger called
# "root", which is a different object from the actual root logger returned
# by logging.getLogger(); both are filtered here, presumably to cover either
# emission path — confirm which one mlcroissant actually logs through.
logging.getLogger("root").addFilter(
    lambda r: "Could not match" not in r.getMessage()
)
# NOTE(review): logger filters only apply to records logged directly on that
# logger, not to records propagating up from child loggers.
logging.getLogger().addFilter(
    lambda r: "Could not match" not in r.getMessage()
)
# Maximum time allowed for generating records from a single record set.
WAIT_TIME = 10 * 60  # seconds (10 minutes)
def validate_json(file_path):
    """Validate that the file at ``file_path`` contains well-formed JSON.

    Args:
        file_path: Path of the candidate JSON file.

    Returns:
        A ``(passed, message, json_data)`` tuple. ``json_data`` is the parsed
        document on success and ``None`` on failure.
    """
    try:
        # JSON is UTF-8 by specification (RFC 8259); don't rely on the
        # platform's locale encoding (previously unspecified, which breaks
        # non-ASCII files on e.g. Windows default code pages).
        with open(file_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        return True, "The file is valid JSON.", json_data
    except json.JSONDecodeError as e:
        error_message = f"Invalid JSON format: {str(e)}"
        return False, error_message, None
    except Exception as e:
        # Deliberately broad: any read problem (missing file, permissions,
        # undecodable bytes) is reported as a failed check, not raised.
        error_message = f"Error reading file: {str(e)}"
        return False, error_message, None
# Top-level Croissant fields that must be present (and truthy) for a
# NeurIPS submission, checked by validate_croissant() beyond schema validity.
REQUIRED_SCHEMA_FIELDS = ["license"]
def validate_croissant(json_data):
    """Validate that the JSON follows the Croissant schema.

    Returns a ``(passed, message, status)`` tuple, where ``status`` is
    "pass" or "error" (a missing required field still returns passed=True
    but with an "error" status so the report flags it).
    """
    try:
        # Constructing the Dataset runs mlcroissant's schema validation;
        # it raises mlc.ValidationError on structural problems.
        mlc.Dataset(jsonld=json_data)
        absent = [f for f in REQUIRED_SCHEMA_FIELDS if not json_data.get(f)]
        if not absent:
            return True, "The dataset passes Croissant validation.", "pass"
        return True, (
            "The `license` field is missing. This is required for NeurIPS dataset submissions. "
            "Please add a `license` field to your Croissant file with the name or URL of the licence governing your dataset. "
            "Where possible, use <a href='https://www.kaggle.com/discussions/getting-started/116476' target='_blank'>open licenses</a> that "
            "allow reuse and reproducibility. However, when the dataset contains sensitive data or stricter licensing is unavoidable, "
            "please select an appropriate license that is as open as possible given the constraints. "
            "You can use our <a href='https://huggingface.co/spaces/JoaquinVanschoren/croissant-rai-checker' target='_blank'>online RAI editing tool</a> to fill this information."
        ), "error"
    except mlc.ValidationError as e:
        return False, f"Validation failed: {str(e)}\n\n{traceback.format_exc()}", "error"
    except Exception as e:
        return False, f"Unexpected error during validation: {str(e)}\n\n{traceback.format_exc()}", "error"
def try_generate_record(record_collection):
    """Attempt to materialize one record from *record_collection*.

    Returns the string "success" when a single record can be pulled (or the
    collection is empty); returns the raised exception object on failure.
    """
    try:
        # Pulling at most one item is enough to exercise the generation path.
        next(iter(record_collection), None)
        return "success"
    except Exception as exc:
        return exc
def validate_records(json_data):
    """Validate that records can be generated within the time limit.

    Returns a ``(passed, message, status)`` tuple; a timeout or generation
    failure is reported with status "warning", setup failures with "error".
    """
    try:
        dataset = mlc.Dataset(jsonld=json_data)
        record_sets = dataset.metadata.record_sets
        if not record_sets:
            return True, "No record sets found to validate.", "pass"
        passed_messages = []
        for rs in record_sets:
            try:
                # Bind rs as a default so the lambda is self-contained.
                outcome = func_timeout.func_timeout(
                    WAIT_TIME,
                    lambda rs=rs: try_generate_record(dataset.records(record_set=rs.uuid)),
                )
                if isinstance(outcome, Exception):
                    raise outcome  # re-raise actual error outside timeout
                passed_messages.append(f"Record set '{rs.uuid}' passed validation.")
            except func_timeout.exceptions.FunctionTimedOut:
                return False, f"Record set '{rs.uuid}' generation took too long (>10 minutes).", "warning"
            except Exception as e:
                failure_message = (
                    f"Record set '{rs.uuid}' failed due to generation error:\n\n"
                    f"```text\n{str(e)}\n\n{traceback.format_exc()}```"
                )
                return False, failure_message, "warning"
        return True, "\n".join(passed_messages), "pass"
    except Exception as e:
        return False, f"Unexpected error during records validation: {str(e)}\n\n{traceback.format_exc()}", "error"
# Responsible AI metadata keys that NeurIPS submissions must carry.
RAI_FIELDS = [
    "rai:dataLimitations",
    "rai:dataBiases",
    "rai:personalSensitiveInformation",
    "rai:dataUseCases",
    "rai:dataSocialImpact",
    "rai:hasSyntheticData",
]
RAI_GUIDELINES_URL = "https://neurips.cc/Conferences/2026/EvaluationsDatasetsHosting"


def validate_rai(json_data):
    """Check that all required Responsible AI metadata fields are present.

    Returns a ``(passed, message)`` tuple; the failure message lists every
    missing field and links to the NeurIPS guidelines.
    """
    absent = [name for name in RAI_FIELDS if name not in json_data]
    if not absent:
        return True, "All required Responsible AI metadata fields are present."
    bullet_list = "\n".join(f"- `{name}`" for name in absent)
    return False, (
        f"The following required Responsible AI metadata fields are missing:\n{bullet_list}\n\n"
        f"Please refer to the <a href='{RAI_GUIDELINES_URL}' target='_blank'>NeurIPS guidelines for instructions</a> on how to add them.\n\n You can use our <a href='https://huggingface.co/spaces/JoaquinVanschoren/croissant-rai-checker' target='_blank'>online RAI editing tool</a> to facilitate this process."
    )
def generate_validation_report(filename, json_data, results):
    """Generate a detailed validation report in markdown format.

    Args:
        filename: Name of the validated file, echoed into the report header.
        json_data: The parsed Croissant JSON-LD, embedded verbatim at the end.
        results: Sequence of ``(test_name, passed, message)`` or
            ``(test_name, passed, message, status)`` tuples, with ``status``
            one of "pass", "warning", "error".

    Returns:
        The complete report as a single markdown string.
    """
    report = []
    report.append("# CROISSANT VALIDATION REPORT")
    report.append("=" * 80)
    report.append("## VALIDATION RESULTS")
    report.append("-" * 80)
    # Bug fix: the filename was previously hard-coded as "(unknown)" even
    # though every caller passes the real name in.
    report.append(f"Starting validation for file: {filename}")
    # Add validation results
    for result in results:
        if len(result) == 4:
            test_name, passed, message, status = result
        else:
            # Legacy 3-tuples carry no explicit status; derive it from `passed`.
            test_name, passed, message = result
            status = "pass" if passed else "error"
        report.append(f"### {test_name}")
        if status == "pass":
            report.append("✓")
        elif status == "warning":
            report.append("?")  # Question mark for warning
        else:
            report.append("✗")
        report.append(message.strip())  # Remove any trailing newlines
    # Add JSON-LD reference
    report.append("## JSON-LD REFERENCE")
    report.append("=" * 80)
    report.append("```json")
    report.append(json.dumps(json_data, indent=2))
    report.append("```")
    return "\n".join(report)