Spaces:
Sleeping
Sleeping
Scalability improvement: HF login, dynamic test updates, improved warnings for downloading issues
b81d3dd | import mlcroissant._src.operation_graph.operations.download as dl_mod | |
| import requests as _requests_mod | |
| import requests | |
| import os | |
| _SERVER_HF_TOKEN = os.environ.get("HF_TOKEN") | |
| print("[DEBUG] HF_TOKEN is", "set" if _SERVER_HF_TOKEN else "missing") | |
| # _active_token holds the HF token to use for the current validation request. | |
| # It defaults to the server-level HF_TOKEN but can be overridden per-request | |
| # via set_active_token() so that logged-in users' own tokens are used instead. | |
| _active_token: dict = {"token": _SERVER_HF_TOKEN} | |
| def set_active_token(token: str | None) -> None: | |
| """Set the HF token to use for the current validation request.""" | |
| _active_token["token"] = token if token else _SERVER_HF_TOKEN | |
| def clear_active_token() -> None: | |
| """Reset the HF token back to the server-level default.""" | |
| _active_token["token"] = _SERVER_HF_TOKEN | |
# Patch requests.Session.send so an HTTP 429 fails fast instead of letting
# mlcroissant / fsspec / huggingface_hub retry silently for minutes.
_orig_session_send = _requests_mod.Session.send


def _rate_limit_aware_send(self, request, **kwargs):
    """Delegate to the real Session.send, but raise immediately on HTTP 429."""
    response = _orig_session_send(self, request, **kwargs)
    if response.status_code != 429:
        return response
    retry_after = response.headers.get("Retry-After", "unknown")
    raise _requests_mod.exceptions.HTTPError(
        f"HTTP 429 Too Many Requests for {request.url}. "
        f"Retry-After: {retry_after}s. "
        "You are being rate limited. Log in with your Hugging Face account to avoid this.",
        response=response,
    )


_requests_mod.Session.send = _rate_limit_aware_send
# Only attach HF credentials when the download target is huggingface.co.
# The stock get_basic_auth_from_env() applies auth to ALL URLs, which makes
# non-HF hosts (e.g. OpenML) respond with 400 Bad Request.
_orig_download_from_http = dl_mod.Download._download_from_http


def _hf_aware_download(self, filepath):
    """Wrap Download._download_from_http, scoping basic-auth env vars to HF URLs."""
    url = self.node.content_url or ""
    token = _active_token["token"]
    if token and "huggingface.co" in url:
        os.environ["CROISSANT_BASIC_AUTH_USERNAME"] = "hf_user"
        os.environ["CROISSANT_BASIC_AUTH_PASSWORD"] = token
    else:
        # Strip any stale credentials so non-HF hosts get a clean request.
        for var in ("CROISSANT_BASIC_AUTH_USERNAME", "CROISSANT_BASIC_AUTH_PASSWORD"):
            os.environ.pop(var, None)
    return _orig_download_from_http(self, filepath)


dl_mod.Download._download_from_http = _hf_aware_download
| import logging | |
| import mlcroissant as mlc | |
| import func_timeout | |
| import json | |
| import traceback | |
# Suppress noisy mlcroissant pattern-matching warnings.
# FIX: logging.getLogger("root") and logging.getLogger() both return the
# root logger on Python 3.9+ (this file requires 3.10+ for `str | None`),
# so the original code registered the same filter twice; once suffices.
# NOTE(review): logger-level filters only apply to records logged directly
# on this logger — records propagated from child loggers (e.g.
# "mlcroissant.*") bypass them; attach the filter to a handler instead if
# those must be suppressed too — confirm against mlcroissant's logging.
logging.getLogger().addFilter(
    lambda r: "Could not match" not in r.getMessage()
)

WAIT_TIME = 10 * 60  # seconds; per-record-set generation timeout
def validate_json(file_path):
    """Validate that the file at *file_path* contains well-formed JSON.

    Args:
        file_path: Path to the candidate JSON file.

    Returns:
        Tuple ``(passed, message, json_data)`` where *json_data* is the parsed
        document on success and ``None`` on any failure.
    """
    try:
        # FIX: JSON text is UTF-8 by spec (RFC 8259); don't depend on the
        # platform locale's default encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        return True, "The file is valid JSON.", json_data
    except json.JSONDecodeError as e:
        error_message = f"Invalid JSON format: {str(e)}"
        return False, error_message, None
    except Exception as e:
        # Covers I/O errors (missing file, permissions) and bad byte sequences.
        error_message = f"Error reading file: {str(e)}"
        return False, error_message, None
# Top-level JSON-LD fields that NeurIPS submissions must populate.
REQUIRED_SCHEMA_FIELDS = ["license"]


def validate_croissant(json_data):
    """Validate that the JSON follows Croissant schema.

    Returns ``(passed, message, status)``; a missing required field still
    reports ``passed=True`` but with status ``"error"`` so the UI flags it.
    """
    try:
        # Constructing the Dataset performs schema validation and raises
        # mlc.ValidationError on failure; the instance itself is not needed.
        mlc.Dataset(jsonld=json_data)
        absent = [field for field in REQUIRED_SCHEMA_FIELDS if not json_data.get(field)]
        if not absent:
            return True, "The dataset passes Croissant validation.", "pass"
        return True, (
            "The `license` field is missing. This is required for NeurIPS dataset submissions. "
            "Please add a `license` field to your Croissant file with the name or URL of the licence governing your dataset. "
            "Where possible, use <a href='https://www.kaggle.com/discussions/getting-started/116476' target='_blank'>open licenses</a> that "
            "allow reuse and reproducibility. However, when the dataset contains sensitive data or stricter licensing is unavoidable, "
            "please select an appropriate license that is as open as possible given the constraints. "
            "You can use our <a href='https://huggingface.co/spaces/JoaquinVanschoren/croissant-rai-checker' target='_blank'>online RAI editing tool</a> to fill this information."
        ), "error"
    except mlc.ValidationError as e:
        return False, f"Validation failed: {str(e)}\n\n{traceback.format_exc()}", "error"
    except Exception as e:
        return False, f"Unexpected error during validation: {str(e)}\n\n{traceback.format_exc()}", "error"
def try_generate_record(record_collection):
    """Attempt to materialise a single record from *record_collection*.

    Returns the string ``"success"`` when one record (or none at all, for an
    empty collection) could be drawn without error; otherwise returns the
    raised exception object so the caller can re-raise it outside a
    func_timeout wrapper.
    """
    try:
        next(iter(record_collection), None)  # pull at most one record
        return "success"
    except Exception as exc:
        return exc
def validate_records(json_data):
    """Validate that records can be generated within the time limit.

    Tries to draw one record from every record set, each under a WAIT_TIME
    timeout.  Returns ``(passed, message, status)``; generation failures and
    timeouts report status ``"warning"``, setup failures ``"error"``.
    """
    try:
        dataset = mlc.Dataset(jsonld=json_data)
        record_sets = dataset.metadata.record_sets
        if not record_sets:
            return True, "No record sets found to validate.", "pass"

        passed_messages = []
        for rs in record_sets:
            try:
                # try_generate_record returns the exception instead of raising
                # so func_timeout can't mistake it for its own timeout error.
                outcome = func_timeout.func_timeout(
                    WAIT_TIME,
                    lambda rs=rs: try_generate_record(dataset.records(record_set=rs.uuid)),
                )
                if isinstance(outcome, Exception):
                    raise outcome  # re-raise actual error outside timeout
                passed_messages.append(f"Record set '{rs.uuid}' passed validation.")
            except func_timeout.exceptions.FunctionTimedOut:
                return False, f"Record set '{rs.uuid}' generation took too long (>10 minutes).", "warning"
            except Exception as e:
                details = traceback.format_exc()
                return (
                    False,
                    f"Record set '{rs.uuid}' failed due to generation error:\n\n"
                    f"```text\n{str(e)}\n\n{details}```",
                    "warning",
                )
        return True, "\n".join(passed_messages), "pass"
    except Exception as e:
        details = traceback.format_exc()
        return False, f"Unexpected error during records validation: {str(e)}\n\n{details}", "error"
# Responsible-AI metadata fields every submission must declare (top level).
RAI_FIELDS = [
    "rai:dataLimitations",
    "rai:dataBiases",
    "rai:personalSensitiveInformation",
    "rai:dataUseCases",
    "rai:dataSocialImpact",
    "rai:hasSyntheticData",
]

RAI_GUIDELINES_URL = "https://neurips.cc/Conferences/2026/EvaluationsDatasetsHosting"


def validate_rai(json_data):
    """Check that all required Responsible AI metadata fields are present.

    Returns ``(passed, message)``; on failure the message lists each missing
    field as a markdown bullet and links to the NeurIPS guidelines.
    """
    absent = [name for name in RAI_FIELDS if name not in json_data]
    if not absent:
        return True, "All required Responsible AI metadata fields are present."
    bullet_lines = "\n".join(f"- `{name}`" for name in absent)
    return False, (
        f"The following required Responsible AI metadata fields are missing:\n{bullet_lines}\n\n"
        f"Please refer to the <a href='{RAI_GUIDELINES_URL}' target='_blank'>NeurIPS guidelines for instructions</a> on how to add them.\n\n You can use our <a href='https://huggingface.co/spaces/JoaquinVanschoren/croissant-rai-checker' target='_blank'>online RAI editing tool</a> to facilitate this process."
    )
def generate_validation_report(filename, json_data, results):
    """Generate a detailed validation report in markdown format.

    Args:
        filename: Name of the validated file, echoed in the report header.
        json_data: Parsed Croissant JSON-LD, appended verbatim as reference.
        results: Iterable of ``(test_name, passed, message)`` or
            ``(test_name, passed, message, status)`` tuples, where status is
            one of "pass"/"warning"/"error" and is derived from *passed*
            when omitted.

    Returns:
        The complete report as one markdown string.
    """
    report = []
    report.append("# CROISSANT VALIDATION REPORT")
    report.append("=" * 80)
    report.append("## VALIDATION RESULTS")
    report.append("-" * 80)
    # FIX: the original appended a placeholder-free f-string, leaving the
    # `filename` parameter unused — the report never named the file.
    report.append(f"Starting validation for file: {filename}")

    # Add validation results
    for result in results:
        if len(result) == 4:
            test_name, passed, message, status = result
        else:
            test_name, passed, message = result
            status = "pass" if passed else "error"
        report.append(f"### {test_name}")
        if status == "pass":
            report.append("✓")
        elif status == "warning":
            report.append("?")  # Question mark for warning
        else:
            report.append("✗")
        report.append(message.strip())  # Remove any trailing newlines

    # Add JSON-LD reference so the report is self-contained.
    report.append("## JSON-LD REFERENCE")
    report.append("=" * 80)
    report.append("```json")
    report.append(json.dumps(json_data, indent=2))
    report.append("```")
    return "\n".join(report)