Alyafeai commited on
Commit
034f762
·
1 Parent(s): 5e5f8d3

fix issue with some parquet files have different formats

Browse files
Files changed (1) hide show
  1. backend/data_loader.py +19 -2
backend/data_loader.py CHANGED
@@ -414,11 +414,28 @@ def _as_dict(value: Any) -> Dict[str, Any]:
414
 
415
 
416
  def _py_scalar(value: Any) -> Any:
 
 
 
 
 
 
417
  if isinstance(value, np.generic):
418
  return value.item()
419
  return value
420
 
421
 
 
 
 
 
 
 
 
 
 
 
 
422
  def _extract_predicted_answer(model_response: Dict[str, Any], choices: List[Any]) -> Any:
423
  logprobs = model_response.get("logprobs")
424
  if logprobs is not None and choices:
@@ -500,7 +517,7 @@ def _read_detail_parquet(path: str, subtask: str) -> List[Dict[str, Any]]:
500
  or ""
501
  )
502
 
503
- rows.append({
504
  "subtask": subtask,
505
  "question_id": _py_scalar(doc.get("id")),
506
  "task_name": _py_scalar(doc.get("task_name")),
@@ -513,7 +530,7 @@ def _read_detail_parquet(path: str, subtask: str) -> List[Dict[str, Any]]:
513
  "is_correct": is_correct,
514
  "metric_name": metric_name,
515
  "metric": metric_value,
516
- })
517
 
518
  return rows
519
 
 
414
 
415
 
416
  def _py_scalar(value: Any) -> Any:
417
+ if isinstance(value, np.ndarray):
418
+ if value.ndim == 0:
419
+ return _py_scalar(value.item())
420
+ if value.size == 1:
421
+ return _py_scalar(value.reshape(-1)[0])
422
+ return [_py_scalar(v) for v in value.tolist()]
423
  if isinstance(value, np.generic):
424
  return value.item()
425
  return value
426
 
427
 
428
+ def _json_safe(value: Any) -> Any:
429
+ value = _py_scalar(value)
430
+ if isinstance(value, dict):
431
+ return {str(k): _json_safe(v) for k, v in value.items()}
432
+ if isinstance(value, list):
433
+ return [_json_safe(v) for v in value]
434
+ if isinstance(value, tuple):
435
+ return [_json_safe(v) for v in value]
436
+ return value
437
+
438
+
439
  def _extract_predicted_answer(model_response: Dict[str, Any], choices: List[Any]) -> Any:
440
  logprobs = model_response.get("logprobs")
441
  if logprobs is not None and choices:
 
517
  or ""
518
  )
519
 
520
+ rows.append(_json_safe({
521
  "subtask": subtask,
522
  "question_id": _py_scalar(doc.get("id")),
523
  "task_name": _py_scalar(doc.get("task_name")),
 
530
  "is_correct": is_correct,
531
  "metric_name": metric_name,
532
  "metric": metric_value,
533
+ }))
534
 
535
  return rows
536