Harshit Ghosh Copilot committed on
Commit
5105d0e
·
1 Parent(s): e3566c9

phase 3 ig

Browse files

Co-authored-by: Copilot <copilot@github.com>

Files changed (9) hide show
  1. app_new.py +290 -99
  2. auth_routes.py +3 -3
  3. auth_utils.py +2 -3
  4. models.py +11 -6
  5. run_interface.py +27 -10
  6. security.py +8 -3
  7. static/css/error_pages.css +7 -11
  8. tasks.py +273 -81
  9. templates/404.html +0 -25
app_new.py CHANGED
@@ -33,6 +33,7 @@ from dataclasses import dataclass
33
  from getpass import getpass
34
  from pathlib import Path
35
  from typing import Any
 
36
 
37
  try:
38
  from dotenv import load_dotenv
@@ -77,7 +78,7 @@ from werkzeug.middleware.proxy_fix import ProxyFix
77
  from flask_login import current_user, login_required
78
 
79
  # Import new security and auth modules
80
- from models import db, User, ScreeningReport, ScreeningUpload
81
  from auth_utils import init_auth, log_audit, get_client_ip
82
  from auth_routes import auth_bp
83
  from data_isolation import UserDataManager
@@ -125,6 +126,49 @@ HF_MODEL_REPO = os.environ.get("ICH_HF_MODEL_REPO", "").strip()
125
  HF_TOKEN = os.environ.get("ICH_HF_TOKEN", "").strip()
126
  LOCAL_MODE = _env_bool("ICH_LOCAL_MODE", True)
127
  SHOW_LOGS = _env_bool("ICH_SHOW_LOGS", False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  # ══════════════════════════════════════════════════════════════════════════
130
  # FLASK APP SETUP
@@ -316,6 +360,99 @@ def _ensure_model_loaded() -> bool:
316
  logger.error(f"Model loading failed: {e}", exc_info=True)
317
  return False
318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  # ══════════════════════════════════════════════════════════════════════════
320
  # INFERENCE & BATCH PROCESSING
321
  # ══════════════════════════════════════════════════════════════════════════
@@ -331,7 +468,6 @@ def _run_inference_on_dcm(
331
 
332
  ri_mod = _MODEL["inference_mod"]
333
  image_id = dcm_path.stem
334
- user_reports_dir = UserDataManager().get_user_reports_dir(user_id)
335
 
336
  bbr.start()
337
 
@@ -345,55 +481,7 @@ def _run_inference_on_dcm(
345
  _MODEL["device"],
346
  _MODEL["temperature"],
347
  )
348
-
349
- user_reports_dir.mkdir(parents=True, exist_ok=True)
350
- report = ri_mod.build_report(
351
- image_id, inference, _MODEL["calib_cfg"],
352
- user_reports_dir, img_rgb, true_label=None,
353
- )
354
-
355
- pred = report.get("prediction", {})
356
- pred.setdefault("raw_probability", inference.get("raw_prob_any"))
357
- pred.setdefault("calibrated_probability", inference.get("cal_prob_any"))
358
- pred.setdefault("decision_threshold", pred.get("decision_threshold_any"))
359
- report["prediction"] = pred
360
-
361
- explainability = report.get("explainability", {}) if isinstance(report, dict) else {}
362
- gradcam_reference = (
363
- report.get("cloudinary_heatmap_url")
364
- or explainability.get("heatmap_path")
365
- or explainability.get("image_path")
366
- )
367
-
368
- report_path = user_reports_dir / f"{image_id}_report.json"
369
- with open(report_path, "w") as f:
370
- json.dump(report, f, indent=2)
371
-
372
- # Save to database
373
- user_data_dir = UserDataManager().get_user_data_dir(user_id)
374
-
375
- screening_report = ScreeningReport(
376
- user_id=user_id,
377
- upload_id=upload_id,
378
- image_id=image_id,
379
- screening_outcome=pred.get("screening_outcome"),
380
- raw_probability=pred.get("raw_probability"),
381
- calibrated_probability=pred.get("calibrated_probability"),
382
- confidence_band=pred.get("confidence_band"),
383
- decision_threshold=pred.get("decision_threshold"),
384
- triage_action=report.get("triage", {}).get("action"),
385
- urgency=report.get("triage", {}).get("urgency"),
386
- report_json_path=str(report_path.relative_to(user_data_dir)),
387
- gradcam_image_path=gradcam_reference,
388
- llm_summary=report.get("llm_summary"),
389
- report_payload=json.dumps(report, ensure_ascii=True),
390
- generated_at=datetime.datetime.utcnow(),
391
- )
392
- db.session.add(screening_report)
393
- db.session.commit()
394
-
395
- log_audit("inference_completed", user_id=user_id, resource_type="report",
396
- resource_id=screening_report.id, status="success")
397
 
398
  except Exception as e:
399
  db.session.rollback()
@@ -405,7 +493,7 @@ def _run_inference_on_dcm(
405
  bbr.stop()
406
 
407
  # Save trace
408
- ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
409
  base = f"{ts}_{image_id}"
410
  try:
411
  bbr.save_report(str(LOGS_DIR / f"{base}.txt"))
@@ -419,18 +507,25 @@ def _start_batch(dcm_paths: list[Path], user_id: int, temp_dir: str | None = Non
419
  """Trigger async batch processing via Celery."""
420
  batch_id = f"u{user_id}_{uuid.uuid4().hex[:12]}"
421
  dcm_paths_str = [str(p) for p in dcm_paths]
 
 
 
422
 
423
  # Send task to Celery worker
424
  try:
 
 
 
 
 
 
 
 
 
425
  task = celery_app.send_task(
426
  "tasks.process_dicom_batch",
427
- kwargs={
428
- "batch_id": batch_id,
429
- "dcm_paths": dcm_paths_str,
430
- "user_id": user_id,
431
- "temp_dir": temp_dir,
432
- },
433
- task_id=batch_id,
434
  )
435
  except Exception as exc:
436
  logger.error("Failed to enqueue Celery batch task", exc_info=True)
@@ -440,13 +535,18 @@ def _start_batch(dcm_paths: list[Path], user_id: int, temp_dir: str | None = Non
440
  return batch_id
441
 
442
 
 
 
 
 
443
  def _run_batch_sync(dcm_paths: list[Path], user_id: int, temp_dir: str | None = None) -> dict[str, Any]:
444
  """Fallback synchronous batch processing when Celery is unavailable."""
445
  total = len(dcm_paths)
446
  succeeded_ids: list[str] = []
447
  failed_ids: list[str] = []
448
- started_at = datetime.datetime.now().isoformat()
449
  sync_batch_id = f"sync_u{user_id}_{uuid.uuid4().hex[:12]}"
 
450
 
451
  log_audit(
452
  "batch_sync_started",
@@ -458,39 +558,106 @@ def _run_batch_sync(dcm_paths: list[Path], user_id: int, temp_dir: str | None =
458
  user_upload_dir = UserDataManager().get_user_upload_dir(user_id)
459
 
460
  try:
461
- for path in dcm_paths:
462
- image_id = path.stem
463
-
464
- upload_record = ScreeningUpload(
465
- user_id=user_id,
466
- file_name=path.name,
467
- original_filename=path.name,
468
- file_size=path.stat().st_size if path.exists() else None,
469
- file_path=str(path.relative_to(user_upload_dir)) if path.parent == user_upload_dir else str(path),
470
- processing_status="processing",
471
  )
472
- db.session.add(upload_record)
473
- db.session.commit()
474
-
475
- try:
476
- report, _ = _run_inference_on_dcm(path, user_id, upload_record.id)
477
- if report:
478
- upload_record.processing_status = "completed"
 
 
 
 
 
479
  db.session.commit()
480
- succeeded_ids.append(image_id)
481
- else:
482
- upload_record.processing_status = "failed"
483
- db.session.commit()
484
- failed_ids.append(image_id)
485
- except Exception as exc:
486
- logger.error(f"Sync batch failed {image_id} — {exc}", exc_info=True)
487
- db.session.rollback()
488
- upload_record.processing_status = "failed"
489
  try:
490
- db.session.commit()
491
- except Exception:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
492
  db.session.rollback()
493
- failed_ids.append(image_id)
 
 
 
 
 
494
  finally:
495
  if temp_dir and Path(temp_dir).exists():
496
  try:
@@ -520,7 +687,7 @@ def _run_batch_sync(dcm_paths: list[Path], user_id: int, temp_dir: str | None =
520
  "image_ids": list(succeeded_ids),
521
  "current_file": "",
522
  "started_at": started_at,
523
- "finished_at": datetime.datetime.now().isoformat(),
524
  "error": None,
525
  "temp_dir": temp_dir,
526
  }
@@ -576,13 +743,7 @@ class CaseRow:
576
 
577
  @property
578
  def date_display(self) -> str:
579
- if not self.generated_at:
580
- return "—"
581
- try:
582
- dt = datetime.datetime.fromisoformat(self.generated_at)
583
- return dt.strftime("%Y-%m-%d %H:%M")
584
- except (ValueError, TypeError):
585
- return self.generated_at[:16]
586
 
587
  @property
588
  def is_positive(self) -> bool:
@@ -1244,7 +1405,7 @@ def case_detail(image_id):
1244
  triage=report.triage_action or "N/A",
1245
  urgency=report.urgency or "N/A",
1246
  generated_at=_format_date(report.generated_at),
1247
- date_display=(report.generated_at.strftime("%Y-%m-%d %H:%M") if report.generated_at else "—"),
1248
  report_file=Path(report.report_json_path).name if report.report_json_path else None,
1249
  gradcam_url=gradcam_url,
1250
  true_label=report.true_label,
@@ -1295,10 +1456,15 @@ def logs_page():
1295
  if LOGS_DIR.exists():
1296
  for path in sorted(LOGS_DIR.iterdir(), reverse=True)[:50]: # Last 50 logs
1297
  if path.suffix in (".txt", ".json"):
 
 
 
 
 
1298
  log_files.append({
1299
  "name": path.name,
1300
  "size": round(path.stat().st_size / 1024, 1),
1301
- "modified": datetime.datetime.fromtimestamp(path.stat().st_mtime).isoformat(),
1302
  })
1303
 
1304
  return render_template("logs.html", logs=log_files)
@@ -1419,6 +1585,31 @@ def create_admin():
1419
  db.session.commit()
1420
  print(f"Admin user '{username}' created!")
1421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1422
  # ══════════════════════════════════════════════════════════════════════════
1423
  # MAIN
1424
  # ══════════════════════════════════════════════════════════════════════════
 
33
  from getpass import getpass
34
  from pathlib import Path
35
  from typing import Any
36
+ from zoneinfo import ZoneInfo
37
 
38
  try:
39
  from dotenv import load_dotenv
 
78
  from flask_login import current_user, login_required
79
 
80
  # Import new security and auth modules
81
+ from models import db, User, ScreeningReport, ScreeningUpload, AuditLog
82
  from auth_utils import init_auth, log_audit, get_client_ip
83
  from auth_routes import auth_bp
84
  from data_isolation import UserDataManager
 
126
  HF_TOKEN = os.environ.get("ICH_HF_TOKEN", "").strip()
127
  LOCAL_MODE = _env_bool("ICH_LOCAL_MODE", True)
128
  SHOW_LOGS = _env_bool("ICH_SHOW_LOGS", False)
129
+ GPU_BATCH_ENABLED = _env_bool("ICH_GPU_BATCH_INFERENCE", True)
130
+ GPU_BATCH_SIZE = _env_int("ICH_GPU_BATCH_SIZE", 2, minimum=1)
131
+ GPU_QUEUE_ENABLED = _env_bool("ICH_GPU_QUEUE_ENABLED", False)
132
+ GPU_QUEUE_NAME = os.environ.get("ICH_GPU_QUEUE_NAME", "gpu").strip() or "gpu"
133
+ CPU_QUEUE_NAME = os.environ.get("ICH_CPU_QUEUE_NAME", "cpu").strip() or "cpu"
134
+ IST = ZoneInfo("Asia/Kolkata")
135
+
136
+ def _now_ist() -> datetime.datetime:
137
+ return datetime.datetime.now(IST).replace(tzinfo=None)
138
+
139
+ def _as_ist(dt: datetime.datetime | None) -> datetime.datetime | None:
140
+ if dt is None:
141
+ return None
142
+ if dt.tzinfo is None:
143
+ dt = dt.replace(tzinfo=IST)
144
+ return dt.astimezone(IST)
145
+
146
+ def _format_dt_ist(dt: datetime.datetime | None, fmt: str = "%Y-%m-%d %H:%M") -> str:
147
+ local = _as_ist(dt)
148
+ return local.strftime(fmt) if local else "—"
149
+
150
+ def _format_iso_ist(value: str | None, fmt: str = "%Y-%m-%d %H:%M") -> str:
151
+ if not value:
152
+ return "—"
153
+ try:
154
+ parsed = datetime.datetime.fromisoformat(value)
155
+ except Exception:
156
+ return value[:16]
157
+ return _format_dt_ist(parsed, fmt)
158
+
159
+ def _to_ist_naive(dt: datetime.datetime | None) -> datetime.datetime | None:
160
+ if dt is None:
161
+ return None
162
+ if dt.tzinfo is None:
163
+ dt = dt.replace(tzinfo=datetime.timezone.utc)
164
+ return dt.astimezone(IST).replace(tzinfo=None)
165
+
166
+ def _cuda_available() -> bool:
167
+ try:
168
+ import torch
169
+ return torch.cuda.is_available()
170
+ except Exception:
171
+ return False
172
 
173
  # ══════════════════════════════════════════════════════════════════════════
174
  # FLASK APP SETUP
 
360
  logger.error(f"Model loading failed: {e}", exc_info=True)
361
  return False
362
 
363
+
364
+ def _gpu_batch_ready() -> bool:
365
+ if not GPU_BATCH_ENABLED:
366
+ return False
367
+ if not _ensure_model_loaded():
368
+ return False
369
+ return _MODEL.get("device") == "cuda"
370
+
371
+
372
+ def _infer_images_batch(dcm_paths: list[Path]) -> list[tuple[Any, dict[str, Any]]]:
373
+ if not _ensure_model_loaded():
374
+ raise RuntimeError("Model not loaded")
375
+
376
+ ri_mod = _MODEL["inference_mod"]
377
+ images = [ri_mod.dicom_to_rgb(str(path), size=ri_mod.IMG_SIZE) for path in dcm_paths]
378
+ inferences = ri_mod.infer_batch(
379
+ images,
380
+ _MODEL["model"],
381
+ _MODEL["grad_cam"],
382
+ _MODEL["transform"],
383
+ _MODEL["device"],
384
+ _MODEL["temperature"],
385
+ )
386
+ return list(zip(images, inferences, strict=False))
387
+
388
+
389
+ def _persist_inference_result(
390
+ image_id: str,
391
+ user_id: int,
392
+ upload_id: int,
393
+ img_rgb: Any,
394
+ inference: dict[str, Any],
395
+ ) -> dict[str, Any]:
396
+ ri_mod = _MODEL["inference_mod"]
397
+ user_reports_dir = UserDataManager().get_user_reports_dir(user_id)
398
+ user_reports_dir.mkdir(parents=True, exist_ok=True)
399
+
400
+ report = ri_mod.build_report(
401
+ image_id,
402
+ inference,
403
+ _MODEL["calib_cfg"],
404
+ user_reports_dir,
405
+ img_rgb,
406
+ true_label=None,
407
+ )
408
+
409
+ pred = report.get("prediction", {})
410
+ pred.setdefault("raw_probability", inference.get("raw_prob_any"))
411
+ pred.setdefault("calibrated_probability", inference.get("cal_prob_any"))
412
+ pred.setdefault("decision_threshold", pred.get("decision_threshold_any"))
413
+ report["prediction"] = pred
414
+
415
+ explainability = report.get("explainability", {}) if isinstance(report, dict) else {}
416
+ gradcam_reference = (
417
+ report.get("cloudinary_heatmap_url")
418
+ or explainability.get("heatmap_path")
419
+ or explainability.get("image_path")
420
+ )
421
+
422
+ report_path = user_reports_dir / f"{image_id}_report.json"
423
+ with open(report_path, "w") as f:
424
+ json.dump(report, f, separators=(",", ":"), ensure_ascii=True)
425
+
426
+ user_data_dir = UserDataManager().get_user_data_dir(user_id)
427
+ screening_report = ScreeningReport(
428
+ user_id=user_id,
429
+ upload_id=upload_id,
430
+ image_id=image_id,
431
+ screening_outcome=pred.get("screening_outcome"),
432
+ raw_probability=pred.get("raw_probability"),
433
+ calibrated_probability=pred.get("calibrated_probability"),
434
+ confidence_band=pred.get("confidence_band"),
435
+ decision_threshold=pred.get("decision_threshold"),
436
+ triage_action=report.get("triage", {}).get("action"),
437
+ urgency=report.get("triage", {}).get("urgency"),
438
+ report_json_path=str(report_path.relative_to(user_data_dir)),
439
+ gradcam_image_path=gradcam_reference,
440
+ llm_summary=report.get("llm_summary"),
441
+ report_payload=json.dumps(report, ensure_ascii=True, separators=(",", ":")),
442
+ generated_at=_now_ist(),
443
+ )
444
+ db.session.add(screening_report)
445
+ db.session.commit()
446
+
447
+ log_audit(
448
+ "inference_completed",
449
+ user_id=user_id,
450
+ resource_type="report",
451
+ resource_id=screening_report.id,
452
+ status="success",
453
+ )
454
+ return report
455
+
456
  # ══════════════════════════════════════════════════════════════════════════
457
  # INFERENCE & BATCH PROCESSING
458
  # ══════════════════════════════════════════════════════════════════════════
 
468
 
469
  ri_mod = _MODEL["inference_mod"]
470
  image_id = dcm_path.stem
 
471
 
472
  bbr.start()
473
 
 
481
  _MODEL["device"],
482
  _MODEL["temperature"],
483
  )
484
+ report = _persist_inference_result(image_id, user_id, upload_id, img_rgb, inference)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
 
486
  except Exception as e:
487
  db.session.rollback()
 
493
  bbr.stop()
494
 
495
  # Save trace
496
+ ts = _now_ist().strftime("%Y%m%d_%H%M%S")
497
  base = f"{ts}_{image_id}"
498
  try:
499
  bbr.save_report(str(LOGS_DIR / f"{base}.txt"))
 
507
  """Trigger async batch processing via Celery."""
508
  batch_id = f"u{user_id}_{uuid.uuid4().hex[:12]}"
509
  dcm_paths_str = [str(p) for p in dcm_paths]
510
+ queue = None
511
+ if GPU_QUEUE_ENABLED:
512
+ queue = GPU_QUEUE_NAME if _cuda_available() else CPU_QUEUE_NAME
513
 
514
  # Send task to Celery worker
515
  try:
516
+ task_kwargs = {
517
+ "batch_id": batch_id,
518
+ "dcm_paths": dcm_paths_str,
519
+ "user_id": user_id,
520
+ "temp_dir": temp_dir,
521
+ }
522
+ send_kwargs = {"task_id": batch_id}
523
+ if queue:
524
+ send_kwargs["queue"] = queue
525
  task = celery_app.send_task(
526
  "tasks.process_dicom_batch",
527
+ kwargs=task_kwargs,
528
+ **send_kwargs,
 
 
 
 
 
529
  )
530
  except Exception as exc:
531
  logger.error("Failed to enqueue Celery batch task", exc_info=True)
 
535
  return batch_id
536
 
537
 
538
+ def _iter_batches(items: list[Path], batch_size: int) -> list[list[Path]]:
539
+ return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
540
+
541
+
542
  def _run_batch_sync(dcm_paths: list[Path], user_id: int, temp_dir: str | None = None) -> dict[str, Any]:
543
  """Fallback synchronous batch processing when Celery is unavailable."""
544
  total = len(dcm_paths)
545
  succeeded_ids: list[str] = []
546
  failed_ids: list[str] = []
547
+ started_at = _now_ist().isoformat()
548
  sync_batch_id = f"sync_u{user_id}_{uuid.uuid4().hex[:12]}"
549
+ use_gpu_batch = _gpu_batch_ready() and total > 1
550
 
551
  log_audit(
552
  "batch_sync_started",
 
558
  user_upload_dir = UserDataManager().get_user_upload_dir(user_id)
559
 
560
  try:
561
+ if use_gpu_batch:
562
+ logger.info(
563
+ "GPU batch inference enabled (size=%s); per-image traces are skipped.",
564
+ GPU_BATCH_SIZE,
 
 
 
 
 
 
565
  )
566
+ for chunk in _iter_batches(dcm_paths, GPU_BATCH_SIZE):
567
+ upload_records: list[ScreeningUpload] = []
568
+ for path in chunk:
569
+ upload_record = ScreeningUpload(
570
+ user_id=user_id,
571
+ file_name=path.name,
572
+ original_filename=path.name,
573
+ file_size=path.stat().st_size if path.exists() else None,
574
+ file_path=str(path.relative_to(user_upload_dir)) if path.parent == user_upload_dir else str(path),
575
+ processing_status="processing",
576
+ )
577
+ db.session.add(upload_record)
578
  db.session.commit()
579
+ upload_records.append(upload_record)
580
+
 
 
 
 
 
 
 
581
  try:
582
+ batch_results = _infer_images_batch(chunk)
583
+ except Exception as exc:
584
+ logger.error("GPU batch inference failed — %s", exc, exc_info=True)
585
+ for path, upload_record in zip(chunk, upload_records, strict=False):
586
+ image_id = path.stem
587
+ db.session.rollback()
588
+ upload_record.processing_status = "failed"
589
+ try:
590
+ db.session.commit()
591
+ except Exception:
592
+ db.session.rollback()
593
+ failed_ids.append(image_id)
594
+ continue
595
+
596
+ for (path, upload_record), (img_rgb, inference) in zip(
597
+ zip(chunk, upload_records, strict=False),
598
+ batch_results,
599
+ strict=False,
600
+ ):
601
+ image_id = path.stem
602
+ try:
603
+ report = _persist_inference_result(
604
+ image_id,
605
+ user_id,
606
+ upload_record.id,
607
+ img_rgb,
608
+ inference,
609
+ )
610
+ if report:
611
+ upload_record.processing_status = "completed"
612
+ db.session.commit()
613
+ succeeded_ids.append(image_id)
614
+ else:
615
+ upload_record.processing_status = "failed"
616
+ db.session.commit()
617
+ failed_ids.append(image_id)
618
+ except Exception as exc:
619
+ logger.error(f"Sync batch failed {image_id} — {exc}", exc_info=True)
620
+ db.session.rollback()
621
+ upload_record.processing_status = "failed"
622
+ try:
623
+ db.session.commit()
624
+ except Exception:
625
+ db.session.rollback()
626
+ failed_ids.append(image_id)
627
+ else:
628
+ for path in dcm_paths:
629
+ image_id = path.stem
630
+
631
+ upload_record = ScreeningUpload(
632
+ user_id=user_id,
633
+ file_name=path.name,
634
+ original_filename=path.name,
635
+ file_size=path.stat().st_size if path.exists() else None,
636
+ file_path=str(path.relative_to(user_upload_dir)) if path.parent == user_upload_dir else str(path),
637
+ processing_status="processing",
638
+ )
639
+ db.session.add(upload_record)
640
+ db.session.commit()
641
+
642
+ try:
643
+ report, _ = _run_inference_on_dcm(path, user_id, upload_record.id)
644
+ if report:
645
+ upload_record.processing_status = "completed"
646
+ db.session.commit()
647
+ succeeded_ids.append(image_id)
648
+ else:
649
+ upload_record.processing_status = "failed"
650
+ db.session.commit()
651
+ failed_ids.append(image_id)
652
+ except Exception as exc:
653
+ logger.error(f"Sync batch failed {image_id} — {exc}", exc_info=True)
654
  db.session.rollback()
655
+ upload_record.processing_status = "failed"
656
+ try:
657
+ db.session.commit()
658
+ except Exception:
659
+ db.session.rollback()
660
+ failed_ids.append(image_id)
661
  finally:
662
  if temp_dir and Path(temp_dir).exists():
663
  try:
 
687
  "image_ids": list(succeeded_ids),
688
  "current_file": "",
689
  "started_at": started_at,
690
+ "finished_at": _now_ist().isoformat(),
691
  "error": None,
692
  "temp_dir": temp_dir,
693
  }
 
743
 
744
  @property
745
  def date_display(self) -> str:
746
+ return _format_iso_ist(self.generated_at)
 
 
 
 
 
 
747
 
748
  @property
749
  def is_positive(self) -> bool:
 
1405
  triage=report.triage_action or "N/A",
1406
  urgency=report.urgency or "N/A",
1407
  generated_at=_format_date(report.generated_at),
1408
+ date_display=_format_dt_ist(report.generated_at),
1409
  report_file=Path(report.report_json_path).name if report.report_json_path else None,
1410
  gradcam_url=gradcam_url,
1411
  true_label=report.true_label,
 
1456
  if LOGS_DIR.exists():
1457
  for path in sorted(LOGS_DIR.iterdir(), reverse=True)[:50]: # Last 50 logs
1458
  if path.suffix in (".txt", ".json"):
1459
+ modified = datetime.datetime.fromtimestamp(
1460
+ path.stat().st_mtime,
1461
+ tz=datetime.timezone.utc,
1462
+ )
1463
+ modified_local = _as_ist(modified)
1464
  log_files.append({
1465
  "name": path.name,
1466
  "size": round(path.stat().st_size / 1024, 1),
1467
+ "modified": modified_local.isoformat() if modified_local else "",
1468
  })
1469
 
1470
  return render_template("logs.html", logs=log_files)
 
1585
  db.session.commit()
1586
  print(f"Admin user '{username}' created!")
1587
 
1588
+ @app.cli.command()
1589
+ def migrate_utc_to_ist():
1590
+ """Convert existing UTC timestamps to IST (run once)."""
1591
+ with app.app_context():
1592
+ updates = 0
1593
+ models = {
1594
+ User: ["created_at", "updated_at"],
1595
+ ScreeningUpload: ["upload_timestamp"],
1596
+ ScreeningReport: ["generated_at", "created_at"],
1597
+ AuditLog: ["timestamp"],
1598
+ }
1599
+ for model, fields in models.items():
1600
+ for row in model.query.all():
1601
+ changed = False
1602
+ for field in fields:
1603
+ value = getattr(row, field, None)
1604
+ updated = _to_ist_naive(value)
1605
+ if updated and updated != value:
1606
+ setattr(row, field, updated)
1607
+ changed = True
1608
+ if changed:
1609
+ updates += 1
1610
+ db.session.commit()
1611
+ print(f"Migrated timestamps for {updates} rows.")
1612
+
1613
  # ══════════════════════════════════════════════════════════════════════════
1614
  # MAIN
1615
  # ══════════════════════════════════════════════════════════════════════════
auth_routes.py CHANGED
@@ -24,7 +24,7 @@ from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
24
  from sqlalchemy import func, or_
25
 
26
  from auth_utils import log_audit, validate_email, validate_password, validate_username
27
- from models import User, db
28
 
29
  logger = logging.getLogger(__name__)
30
  auth_bp = Blueprint('auth', __name__, url_prefix='/auth')
@@ -57,7 +57,7 @@ def _otp_payload_from_session() -> dict:
57
 
58
  def _store_otp(email: str, purpose: str, user_id: int | None = None) -> str:
59
  code = _generate_otp()
60
- expires_at = datetime.utcnow() + timedelta(minutes=10)
61
  session[OTP_SESSION_KEY] = {
62
  "email": email,
63
  "purpose": purpose,
@@ -98,7 +98,7 @@ def _validate_otp(submitted_code: str, expected_purpose: str) -> tuple[bool, str
98
  _clear_otp()
99
  return False, "OTP is invalid. Please request a new code.", None
100
 
101
- if datetime.utcnow() > expires_at:
102
  _clear_otp()
103
  return False, "OTP expired. Please request a new code.", None
104
 
 
24
  from sqlalchemy import func, or_
25
 
26
  from auth_utils import log_audit, validate_email, validate_password, validate_username
27
+ from models import User, db, now_ist
28
 
29
  logger = logging.getLogger(__name__)
30
  auth_bp = Blueprint('auth', __name__, url_prefix='/auth')
 
57
 
58
  def _store_otp(email: str, purpose: str, user_id: int | None = None) -> str:
59
  code = _generate_otp()
60
+ expires_at = now_ist() + timedelta(minutes=10)
61
  session[OTP_SESSION_KEY] = {
62
  "email": email,
63
  "purpose": purpose,
 
98
  _clear_otp()
99
  return False, "OTP is invalid. Please request a new code.", None
100
 
101
+ if now_ist() > expires_at:
102
  _clear_otp()
103
  return False, "OTP expired. Please request a new code.", None
104
 
auth_utils.py CHANGED
@@ -6,8 +6,7 @@ import logging
6
  from functools import wraps
7
  from flask import session, redirect, url_for, request, g, abort, has_request_context
8
  from flask_login import LoginManager, current_user
9
- from models import db, User, AuditLog
10
- from datetime import datetime
11
  from sqlalchemy.exc import SQLAlchemyError
12
 
13
  logger = logging.getLogger(__name__)
@@ -64,7 +63,7 @@ def log_audit(action, user_id=None, resource_type=None, resource_id=None,
64
  resource_id=resource_id,
65
  details=details,
66
  ip_address=get_client_ip(),
67
- timestamp=datetime.utcnow(),
68
  status=status
69
  )
70
  db.session.add(audit_entry)
 
6
  from functools import wraps
7
  from flask import session, redirect, url_for, request, g, abort, has_request_context
8
  from flask_login import LoginManager, current_user
9
+ from models import db, User, AuditLog, now_ist
 
10
  from sqlalchemy.exc import SQLAlchemyError
11
 
12
  logger = logging.getLogger(__name__)
 
63
  resource_id=resource_id,
64
  details=details,
65
  ip_address=get_client_ip(),
66
+ timestamp=now_ist(),
67
  status=status
68
  )
69
  db.session.add(audit_entry)
models.py CHANGED
@@ -3,12 +3,17 @@ Database models for ICH Screening Application with user authentication and priva
3
  """
4
  import os
5
  from datetime import datetime
 
6
  from flask_sqlalchemy import SQLAlchemy
7
  from flask_login import UserMixin
8
  from werkzeug.security import generate_password_hash, check_password_hash
9
  import secrets
10
 
11
  db = SQLAlchemy()
 
 
 
 
12
 
13
 
14
  class User(UserMixin, db.Model):
@@ -20,8 +25,8 @@ class User(UserMixin, db.Model):
20
  email = db.Column(db.String(120), unique=True, nullable=False, index=True)
21
  password_hash = db.Column(db.String(255), nullable=False)
22
  full_name = db.Column(db.String(120))
23
- created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)
24
- updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
25
  is_active = db.Column(db.Boolean, default=True, nullable=False)
26
 
27
  # Relationships
@@ -50,7 +55,7 @@ class ScreeningUpload(db.Model):
50
  original_filename = db.Column(db.String(255), nullable=False)
51
  file_size = db.Column(db.Integer) # bytes
52
  file_path = db.Column(db.String(500), nullable=False) # Relative to user's upload dir
53
- upload_timestamp = db.Column(db.DateTime, default=datetime.utcnow, nullable=False, index=True)
54
  processing_status = db.Column(db.String(20), default='pending') # pending, processing, completed, failed
55
  processing_error = db.Column(db.Text) # Error message if failed
56
 
@@ -91,8 +96,8 @@ class ScreeningReport(db.Model):
91
  report_payload = db.Column(db.Text)
92
 
93
  # Generated timestamp
94
- generated_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False, index=True)
95
- created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)
96
 
97
  def __repr__(self):
98
  return f'<ScreeningReport {self.id} - user {self.user_id} - {self.image_id}>'
@@ -109,7 +114,7 @@ class AuditLog(db.Model):
109
  resource_id = db.Column(db.String(255))
110
  details = db.Column(db.Text) # JSON or plain text with additional info
111
  ip_address = db.Column(db.String(45)) # IPv4 or IPv6
112
- timestamp = db.Column(db.DateTime, default=datetime.utcnow, nullable=False, index=True)
113
  status = db.Column(db.String(20), default='success') # success, failure
114
 
115
  def __repr__(self):
 
3
  """
4
  import os
5
  from datetime import datetime
6
+ from zoneinfo import ZoneInfo
7
  from flask_sqlalchemy import SQLAlchemy
8
  from flask_login import UserMixin
9
  from werkzeug.security import generate_password_hash, check_password_hash
10
  import secrets
11
 
12
  db = SQLAlchemy()
13
+ IST = ZoneInfo("Asia/Kolkata")
14
+
15
+ def now_ist() -> datetime:
16
+ return datetime.now(IST).replace(tzinfo=None)
17
 
18
 
19
  class User(UserMixin, db.Model):
 
25
  email = db.Column(db.String(120), unique=True, nullable=False, index=True)
26
  password_hash = db.Column(db.String(255), nullable=False)
27
  full_name = db.Column(db.String(120))
28
+ created_at = db.Column(db.DateTime, default=now_ist, nullable=False)
29
+ updated_at = db.Column(db.DateTime, default=now_ist, onupdate=now_ist)
30
  is_active = db.Column(db.Boolean, default=True, nullable=False)
31
 
32
  # Relationships
 
55
  original_filename = db.Column(db.String(255), nullable=False)
56
  file_size = db.Column(db.Integer) # bytes
57
  file_path = db.Column(db.String(500), nullable=False) # Relative to user's upload dir
58
+ upload_timestamp = db.Column(db.DateTime, default=now_ist, nullable=False, index=True)
59
  processing_status = db.Column(db.String(20), default='pending') # pending, processing, completed, failed
60
  processing_error = db.Column(db.Text) # Error message if failed
61
 
 
96
  report_payload = db.Column(db.Text)
97
 
98
  # Generated timestamp
99
+ generated_at = db.Column(db.DateTime, default=now_ist, nullable=False, index=True)
100
+ created_at = db.Column(db.DateTime, default=now_ist, nullable=False)
101
 
102
  def __repr__(self):
103
  return f'<ScreeningReport {self.id} - user {self.user_id} - {self.image_id}>'
 
114
  resource_id = db.Column(db.String(255))
115
  details = db.Column(db.Text) # JSON or plain text with additional info
116
  ip_address = db.Column(db.String(45)) # IPv4 or IPv6
117
+ timestamp = db.Column(db.DateTime, default=now_ist, nullable=False, index=True)
118
  status = db.Column(db.String(20), default='success') # success, failure
119
 
120
  def __repr__(self):
run_interface.py CHANGED
@@ -120,11 +120,25 @@ def infer_single(
120
  device: str,
121
  temperature: float,
122
  ) -> dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
123
  # Build 3ch tensor from the app's transform pipeline, then tile to 9ch
124
  # because the trained model expects 2.5D channels.
125
- t3 = transform(img_rgb).unsqueeze(0).to(device)
 
 
 
 
126
  t9 = torch.cat([t3, t3, t3], dim=1)
127
-
128
  if isinstance(model, list) and isinstance(grad_cam, list):
129
  fold_logits = []
130
  fold_cams = []
@@ -140,14 +154,17 @@ def infer_single(
140
  raw_probs = core.sigmoid_np(logits)
141
  cal_probs = core.sigmoid_np(logits / max(float(temperature), 1e-6))
142
 
143
- return {
144
- "raw_logits": logits,
145
- "raw_probs": raw_probs,
146
- "cal_probs": cal_probs,
147
- "raw_prob_any": float(raw_probs[0]),
148
- "cal_prob_any": float(cal_probs[0]),
149
- "cam": cam,
150
- }
 
 
 
151
 
152
 
153
  def generate_medical_summary(inference: dict[str, Any], calib_cfg: dict[str, Any], report: dict[str, Any]) -> str:
 
120
  device: str,
121
  temperature: float,
122
  ) -> dict[str, Any]:
123
+ return infer_batch([img_rgb], model, grad_cam, transform, device, temperature)[0]
124
+
125
+
126
+ def infer_batch(
127
+ images_rgb: list[np.ndarray],
128
+ model,
129
+ grad_cam: GradCAM,
130
+ transform,
131
+ device: str,
132
+ temperature: float,
133
+ ) -> list[dict[str, Any]]:
134
  # Build 3ch tensor from the app's transform pipeline, then tile to 9ch
135
  # because the trained model expects 2.5D channels.
136
+ if device == "cuda":
137
+ with torch.inference_mode():
138
+ t3 = torch.stack([transform(img) for img in images_rgb], dim=0).to(device)
139
+ else:
140
+ t3 = torch.stack([transform(img) for img in images_rgb], dim=0).to(device)
141
  t9 = torch.cat([t3, t3, t3], dim=1)
 
142
  if isinstance(model, list) and isinstance(grad_cam, list):
143
  fold_logits = []
144
  fold_cams = []
 
154
  raw_probs = core.sigmoid_np(logits)
155
  cal_probs = core.sigmoid_np(logits / max(float(temperature), 1e-6))
156
 
157
+ results = []
158
+ for idx in range(len(images_rgb)):
159
+ results.append({
160
+ "raw_logits": logits[idx],
161
+ "raw_probs": raw_probs[idx],
162
+ "cal_probs": cal_probs[idx],
163
+ "raw_prob_any": float(raw_probs[idx][0]),
164
+ "cal_prob_any": float(cal_probs[idx][0]),
165
+ "cam": cam[idx],
166
+ })
167
+ return results
168
 
169
 
170
  def generate_medical_summary(inference: dict[str, Any], calib_cfg: dict[str, Any], report: dict[str, Any]) -> str:
security.py CHANGED
@@ -5,8 +5,13 @@ import os
5
  import logging
6
  from flask import request
7
  from datetime import datetime, timedelta
 
8
 
9
  logger = logging.getLogger(__name__)
 
 
 
 
10
 
11
 
12
  def init_security(app):
@@ -157,7 +162,7 @@ def get_client_info() -> dict:
157
  'user_agent': request.headers.get('User-Agent', 'Unknown'),
158
  'endpoint': request.endpoint,
159
  'method': request.method,
160
- 'timestamp': datetime.utcnow().isoformat()
161
  }
162
 
163
 
@@ -171,7 +176,7 @@ class RateLimiter:
171
 
172
  def is_rate_limited(self, key: str) -> bool:
173
  """Check if a key has exceeded rate limit"""
174
- now = datetime.utcnow()
175
  window_start = now - timedelta(seconds=self.window_seconds)
176
 
177
  # Clean old entries
@@ -188,7 +193,7 @@ class RateLimiter:
188
 
189
  def record_request(self, key: str):
190
  """Record a request for rate limiting"""
191
- now = datetime.utcnow()
192
 
193
  if key not in self.requests:
194
  self.requests[key] = []
 
5
  import logging
6
  from flask import request
7
  from datetime import datetime, timedelta
8
+ from zoneinfo import ZoneInfo
9
 
10
  logger = logging.getLogger(__name__)
11
+ IST = ZoneInfo("Asia/Kolkata")
12
+
13
def now_ist() -> datetime:
    """Return the current wall-clock time in IST (Asia/Kolkata) as a naive datetime.

    NOTE(review): tzinfo is deliberately stripped so the value can be stored in
    the app's naive DateTime columns while still representing IST local time —
    confirm every consumer expects IST-local naive values.
    """
    aware_now = datetime.now(IST)
    return aware_now.replace(tzinfo=None)
15
 
16
 
17
  def init_security(app):
 
162
  'user_agent': request.headers.get('User-Agent', 'Unknown'),
163
  'endpoint': request.endpoint,
164
  'method': request.method,
165
+ 'timestamp': now_ist().isoformat()
166
  }
167
 
168
 
 
176
 
177
  def is_rate_limited(self, key: str) -> bool:
178
  """Check if a key has exceeded rate limit"""
179
+ now = now_ist()
180
  window_start = now - timedelta(seconds=self.window_seconds)
181
 
182
  # Clean old entries
 
193
 
194
  def record_request(self, key: str):
195
  """Record a request for rate limiting"""
196
+ now = now_ist()
197
 
198
  if key not in self.requests:
199
  self.requests[key] = []
static/css/error_pages.css CHANGED
@@ -75,6 +75,7 @@
75
  .error-code-wrap {
76
  position: relative;
77
  display: inline-block;
 
78
  }
79
  .error-scanline {
80
  position: absolute;
@@ -91,17 +92,6 @@
91
  to { transform:translateY(80px); opacity:0; }
92
  }
93
 
94
- /* SVG illustration */
95
- .error-illustration {
96
- margin: -10px 0 28px;
97
- position: relative; z-index: 1;
98
- animation: float-err 4s ease-in-out infinite;
99
- }
100
- @keyframes float-err {
101
- 0%,100%{ transform:translateY(0); }
102
- 50% { transform:translateY(-10px); }
103
- }
104
-
105
  /* text */
106
  .error-title {
107
  font-size: 1.6rem; font-weight: 800;
@@ -138,6 +128,12 @@
138
  transition: opacity .2s, transform .15s, box-shadow .2s;
139
  border: none; cursor: pointer;
140
  }
 
 
 
 
 
 
141
  .btn-err-primary:hover {
142
  opacity: .9; transform: translateY(-2px);
143
  box-shadow: 0 6px 26px rgba(110,168,254,.45);
 
75
/* Wrapper around the big error-code digits: positioned so the animated
   scanline can be absolutely placed inside it; bottom margin separates the
   code from the title now that the SVG illustration is gone. */
.error-code-wrap {
  position: relative;
  display: inline-block;
  margin-bottom: 18px;
}
80
  .error-scanline {
81
  position: absolute;
 
92
  to { transform:translateY(80px); opacity:0; }
93
  }
94
 
 
 
 
 
 
 
 
 
 
 
 
95
  /* text */
96
  .error-title {
97
  font-size: 1.6rem; font-weight: 800;
 
128
  transition: opacity .2s, transform .15s, box-shadow .2s;
129
  border: none; cursor: pointer;
130
  }
131
/* Inline SVG icons inside the error-page buttons: block display removes the
   inline baseline gap, flex-shrink:0 keeps the icon from being squeezed by
   the flex button, and the 1px nudge optically centers it with the label. */
.btn-err-primary svg,
.btn-err-secondary svg {
  display: block;
  flex-shrink: 0;
  transform: translateY(1px);
}
137
  .btn-err-primary:hover {
138
  opacity: .9; transform: translateY(-2px);
139
  box-shadow: 0 6px 26px rgba(110,168,254,.45);
tasks.py CHANGED
@@ -14,6 +14,7 @@ import sys
14
  import traceback
15
  from pathlib import Path
16
  from typing import Any
 
17
 
18
  # Ensure the app directory is in the Python path so imports work in worker processes
19
  APP_DIR = Path(__file__).parent.absolute()
@@ -29,6 +30,22 @@ except ImportError:
29
  from celery import Celery, current_task
30
 
31
  logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  # Extract Redis URL from environment
34
  REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
@@ -53,14 +70,27 @@ celery_app.conf.update(
53
  task_serializer="json",
54
  accept_content=["json"],
55
  result_serializer="json",
56
- timezone="UTC",
57
- enable_utc=True,
58
  task_track_started=True,
59
  task_time_limit=3600, # 1 hour hard limit
60
  task_soft_time_limit=3300, # 55 min soft limit
61
  result_expires=86400, # 24 hours
62
  )
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  @celery_app.task(bind=True, name="tasks.process_dicom_batch")
66
  def process_dicom_batch(
@@ -111,95 +141,257 @@ def process_dicom_batch(
111
  total = len(dcm_paths)
112
  succeeded_ids = []
113
  failed_ids = []
114
- started_at = datetime.datetime.now().isoformat()
115
 
116
  logger.info(f"Batch {batch_id} starting: {total} files for user {user_id}")
117
 
118
  try:
119
  with app.app_context():
120
- for i, path_str in enumerate(dcm_paths, 1):
121
- # Check if task was revoked (compat across Celery versions)
122
- request_ctx = current_task.request
123
- is_revoked = bool(getattr(request_ctx, "is_revoked", False)) or bool(
124
- getattr(request_ctx, "revoked", False)
125
- )
126
- if is_revoked:
127
- logger.info(f"Batch {batch_id} revoked, stopping")
128
- break
129
-
130
- path = Path(path_str)
131
- image_id = path.stem
132
-
133
- upload_record = ScreeningUpload(
134
- user_id=user_id,
135
- file_name=path.name,
136
- original_filename=path.name,
137
- file_size=path.stat().st_size if path.exists() else None,
138
- file_path=str(path),
139
- processing_status="processing",
140
  )
141
- db.session.add(upload_record)
142
- db.session.commit()
143
-
144
- # Update Celery task state with progress (matches _BATCHES format for frontend)
145
- self.update_state(
146
- state="PROGRESS",
147
- meta={
148
- "batch_id": batch_id,
149
- "user_id": user_id,
150
- "status": "running",
151
- "total": total,
152
- "processed": i - 1,
153
- "succeeded": len(succeeded_ids),
154
- "failed_ids": list(failed_ids),
155
- "image_ids": list(succeeded_ids),
156
- "current_file": image_id,
157
- "started_at": started_at,
158
- "finished_at": None,
159
- "error": None,
160
- "temp_dir": temp_dir,
161
- },
162
  )
 
 
 
 
 
163
 
164
- try:
165
- report, _ = _run_inference_on_dcm(path, user_id, upload_record.id)
166
- if report:
167
- upload_record.processing_status = "completed"
168
- db.session.commit()
169
- succeeded_ids.append(image_id)
170
- else:
171
- upload_record.processing_status = "failed"
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  db.session.commit()
173
- failed_ids.append(image_id)
174
- except Exception as e:
175
- logger.error(f"Batch {batch_id}: failed {image_id} — {e}")
176
- db.session.rollback()
177
- upload_record.processing_status = "failed"
178
  try:
179
- db.session.commit()
180
- except Exception:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  db.session.rollback()
182
- failed_ids.append(image_id)
183
-
184
- # Update after processing each file
185
- self.update_state(
186
- state="PROGRESS",
187
- meta={
188
- "batch_id": batch_id,
189
- "user_id": user_id,
190
- "status": "running",
191
- "total": total,
192
- "processed": i,
193
- "succeeded": len(succeeded_ids),
194
- "failed_ids": list(failed_ids),
195
- "image_ids": list(succeeded_ids),
196
- "current_file": "",
197
- "started_at": started_at,
198
- "finished_at": None,
199
- "error": None,
200
- "temp_dir": temp_dir,
201
- },
202
- )
 
 
 
 
 
203
 
204
  # Cleanup temporary directory if provided
205
  if temp_dir and Path(temp_dir).exists():
@@ -231,7 +423,7 @@ def process_dicom_batch(
231
  "image_ids": list(succeeded_ids),
232
  "current_file": "",
233
  "started_at": started_at,
234
- "finished_at": datetime.datetime.now().isoformat(),
235
  "error": None,
236
  "temp_dir": temp_dir,
237
  }
 
14
  import traceback
15
  from pathlib import Path
16
  from typing import Any
17
+ from zoneinfo import ZoneInfo
18
 
19
  # Ensure the app directory is in the Python path so imports work in worker processes
20
  APP_DIR = Path(__file__).parent.absolute()
 
30
  from celery import Celery, current_task
31
 
32
  logger = logging.getLogger(__name__)
33
+ IST = ZoneInfo("Asia/Kolkata")
34
+
35
def _now_ist() -> datetime.datetime:
    """Current IST (Asia/Kolkata) wall-clock time as a naive datetime.

    tzinfo is stripped on purpose: the Celery config in this module is set to
    the same timezone, so task timestamps stay uniform as IST-local naive
    values.  # NOTE(review): confirm downstream consumers expect naive IST.
    """
    aware_now = datetime.datetime.now(IST)
    return aware_now.replace(tzinfo=None)
37
+
38
+ def _env_int(name: str, default: int | None = None, *, minimum: int | None = None) -> int | None:
39
+ raw = os.environ.get(name)
40
+ if raw is None:
41
+ return default
42
+ try:
43
+ value = int(raw)
44
+ if minimum is not None and value < minimum:
45
+ return default
46
+ return value
47
+ except ValueError:
48
+ return default
49
 
50
  # Extract Redis URL from environment
51
  REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
 
70
  task_serializer="json",
71
  accept_content=["json"],
72
  result_serializer="json",
73
+ timezone="Asia/Kolkata",
74
+ enable_utc=False,
75
  task_track_started=True,
76
  task_time_limit=3600, # 1 hour hard limit
77
  task_soft_time_limit=3300, # 55 min soft limit
78
  result_expires=86400, # 24 hours
79
  )
80
 
81
# Optional worker tuning from the environment: only forward settings that the
# operator explicitly provided (and that pass the >= 1 sanity minimum), so
# Celery's own defaults apply otherwise.
worker_concurrency = _env_int("ICH_CELERY_CONCURRENCY", None, minimum=1)
worker_prefetch = _env_int("ICH_CELERY_PREFETCH_MULTIPLIER", None, minimum=1)
extra_conf: dict[str, Any] = {
    key: value
    for key, value in (
        ("worker_concurrency", worker_concurrency),
        ("worker_prefetch_multiplier", worker_prefetch),
    )
    if value is not None
}
if extra_conf:
    celery_app.conf.update(**extra_conf)
90
+
91
+ def _iter_batches(items: list[str], batch_size: int) -> list[list[str]]:
92
+ return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
93
+
94
 
95
  @celery_app.task(bind=True, name="tasks.process_dicom_batch")
96
  def process_dicom_batch(
 
141
  total = len(dcm_paths)
142
  succeeded_ids = []
143
  failed_ids = []
144
+ started_at = _now_ist().isoformat()
145
 
146
  logger.info(f"Batch {batch_id} starting: {total} files for user {user_id}")
147
 
148
  try:
149
  with app.app_context():
150
+ use_gpu_batch = False
151
+ batch_size = 1
152
+ _infer_images_batch = None
153
+ _persist_inference_result = None
154
+ try:
155
+ from app_new import (
156
+ GPU_BATCH_SIZE,
157
+ _gpu_batch_ready,
158
+ _infer_images_batch,
159
+ _persist_inference_result,
 
 
 
 
 
 
 
 
 
 
160
  )
161
+ use_gpu_batch = _gpu_batch_ready() and total > 1
162
+ batch_size = max(1, GPU_BATCH_SIZE)
163
+ except Exception:
164
+ use_gpu_batch = False
165
+
166
+ if use_gpu_batch and _infer_images_batch and _persist_inference_result:
167
+ logger.info(
168
+ "GPU batch inference enabled (size=%s); per-image traces are skipped.",
169
+ batch_size,
 
 
 
 
 
 
 
 
 
 
 
 
170
  )
171
+ processed = 0
172
+ revoked = False
173
+ for chunk in _iter_batches(dcm_paths, batch_size):
174
+ if revoked:
175
+ break
176
 
177
+ paths = [Path(p) for p in chunk]
178
+ upload_records: list[ScreeningUpload] = []
179
+ for path in paths:
180
+ request_ctx = current_task.request
181
+ is_revoked = bool(getattr(request_ctx, "is_revoked", False)) or bool(
182
+ getattr(request_ctx, "revoked", False)
183
+ )
184
+ if is_revoked:
185
+ logger.info(f"Batch {batch_id} revoked, stopping")
186
+ revoked = True
187
+ break
188
+
189
+ upload_record = ScreeningUpload(
190
+ user_id=user_id,
191
+ file_name=path.name,
192
+ original_filename=path.name,
193
+ file_size=path.stat().st_size if path.exists() else None,
194
+ file_path=str(path),
195
+ processing_status="processing",
196
+ )
197
+ db.session.add(upload_record)
198
  db.session.commit()
199
+ upload_records.append(upload_record)
200
+
201
+ if revoked:
202
+ break
203
+
204
  try:
205
+ batch_results = _infer_images_batch(paths)
206
+ except Exception as exc:
207
+ logger.error(
208
+ f"Batch {batch_id}: GPU batch inference failed — {exc}",
209
+ exc_info=True,
210
+ )
211
+ for path, upload_record in zip(paths, upload_records, strict=False):
212
+ image_id = path.stem
213
+ db.session.rollback()
214
+ upload_record.processing_status = "failed"
215
+ try:
216
+ db.session.commit()
217
+ except Exception:
218
+ db.session.rollback()
219
+ failed_ids.append(image_id)
220
+ processed += 1
221
+ self.update_state(
222
+ state="PROGRESS",
223
+ meta={
224
+ "batch_id": batch_id,
225
+ "user_id": user_id,
226
+ "status": "running",
227
+ "total": total,
228
+ "processed": processed,
229
+ "succeeded": len(succeeded_ids),
230
+ "failed_ids": list(failed_ids),
231
+ "image_ids": list(succeeded_ids),
232
+ "current_file": "",
233
+ "started_at": started_at,
234
+ "finished_at": None,
235
+ "error": None,
236
+ "temp_dir": temp_dir,
237
+ },
238
+ )
239
+ continue
240
+
241
+ for (path, upload_record), (img_rgb, inference) in zip(
242
+ zip(paths, upload_records, strict=False),
243
+ batch_results,
244
+ strict=False,
245
+ ):
246
+ image_id = path.stem
247
+ self.update_state(
248
+ state="PROGRESS",
249
+ meta={
250
+ "batch_id": batch_id,
251
+ "user_id": user_id,
252
+ "status": "running",
253
+ "total": total,
254
+ "processed": processed,
255
+ "succeeded": len(succeeded_ids),
256
+ "failed_ids": list(failed_ids),
257
+ "image_ids": list(succeeded_ids),
258
+ "current_file": image_id,
259
+ "started_at": started_at,
260
+ "finished_at": None,
261
+ "error": None,
262
+ "temp_dir": temp_dir,
263
+ },
264
+ )
265
+
266
+ try:
267
+ report = _persist_inference_result(
268
+ image_id,
269
+ user_id,
270
+ upload_record.id,
271
+ img_rgb,
272
+ inference,
273
+ )
274
+ if report:
275
+ upload_record.processing_status = "completed"
276
+ db.session.commit()
277
+ succeeded_ids.append(image_id)
278
+ else:
279
+ upload_record.processing_status = "failed"
280
+ db.session.commit()
281
+ failed_ids.append(image_id)
282
+ except Exception as exc:
283
+ logger.error(f"Batch {batch_id}: failed {image_id} — {exc}")
284
+ db.session.rollback()
285
+ upload_record.processing_status = "failed"
286
+ try:
287
+ db.session.commit()
288
+ except Exception:
289
+ db.session.rollback()
290
+ failed_ids.append(image_id)
291
+
292
+ processed += 1
293
+ self.update_state(
294
+ state="PROGRESS",
295
+ meta={
296
+ "batch_id": batch_id,
297
+ "user_id": user_id,
298
+ "status": "running",
299
+ "total": total,
300
+ "processed": processed,
301
+ "succeeded": len(succeeded_ids),
302
+ "failed_ids": list(failed_ids),
303
+ "image_ids": list(succeeded_ids),
304
+ "current_file": "",
305
+ "started_at": started_at,
306
+ "finished_at": None,
307
+ "error": None,
308
+ "temp_dir": temp_dir,
309
+ },
310
+ )
311
+ else:
312
+ for i, path_str in enumerate(dcm_paths, 1):
313
+ # Check if task was revoked (compat across Celery versions)
314
+ request_ctx = current_task.request
315
+ is_revoked = bool(getattr(request_ctx, "is_revoked", False)) or bool(
316
+ getattr(request_ctx, "revoked", False)
317
+ )
318
+ if is_revoked:
319
+ logger.info(f"Batch {batch_id} revoked, stopping")
320
+ break
321
+
322
+ path = Path(path_str)
323
+ image_id = path.stem
324
+
325
+ upload_record = ScreeningUpload(
326
+ user_id=user_id,
327
+ file_name=path.name,
328
+ original_filename=path.name,
329
+ file_size=path.stat().st_size if path.exists() else None,
330
+ file_path=str(path),
331
+ processing_status="processing",
332
+ )
333
+ db.session.add(upload_record)
334
+ db.session.commit()
335
+
336
+ # Update Celery task state with progress (matches _BATCHES format for frontend)
337
+ self.update_state(
338
+ state="PROGRESS",
339
+ meta={
340
+ "batch_id": batch_id,
341
+ "user_id": user_id,
342
+ "status": "running",
343
+ "total": total,
344
+ "processed": i - 1,
345
+ "succeeded": len(succeeded_ids),
346
+ "failed_ids": list(failed_ids),
347
+ "image_ids": list(succeeded_ids),
348
+ "current_file": image_id,
349
+ "started_at": started_at,
350
+ "finished_at": None,
351
+ "error": None,
352
+ "temp_dir": temp_dir,
353
+ },
354
+ )
355
+
356
+ try:
357
+ report, _ = _run_inference_on_dcm(path, user_id, upload_record.id)
358
+ if report:
359
+ upload_record.processing_status = "completed"
360
+ db.session.commit()
361
+ succeeded_ids.append(image_id)
362
+ else:
363
+ upload_record.processing_status = "failed"
364
+ db.session.commit()
365
+ failed_ids.append(image_id)
366
+ except Exception as e:
367
+ logger.error(f"Batch {batch_id}: failed {image_id} — {e}")
368
  db.session.rollback()
369
+ upload_record.processing_status = "failed"
370
+ try:
371
+ db.session.commit()
372
+ except Exception:
373
+ db.session.rollback()
374
+ failed_ids.append(image_id)
375
+
376
+ # Update after processing each file
377
+ self.update_state(
378
+ state="PROGRESS",
379
+ meta={
380
+ "batch_id": batch_id,
381
+ "user_id": user_id,
382
+ "status": "running",
383
+ "total": total,
384
+ "processed": i,
385
+ "succeeded": len(succeeded_ids),
386
+ "failed_ids": list(failed_ids),
387
+ "image_ids": list(succeeded_ids),
388
+ "current_file": "",
389
+ "started_at": started_at,
390
+ "finished_at": None,
391
+ "error": None,
392
+ "temp_dir": temp_dir,
393
+ },
394
+ )
395
 
396
  # Cleanup temporary directory if provided
397
  if temp_dir and Path(temp_dir).exists():
 
423
  "image_ids": list(succeeded_ids),
424
  "current_file": "",
425
  "started_at": started_at,
426
+ "finished_at": _now_ist().isoformat(),
427
  "error": None,
428
  "temp_dir": temp_dir,
429
  }
templates/404.html CHANGED
@@ -28,31 +28,6 @@
28
  <div class="error-scanline"></div>
29
  </div>
30
 
31
- <!-- Inline SVG illustration (floating brain scan) -->
32
- <div class="error-illustration">
33
- <svg width="160" height="110" viewBox="0 0 160 110" fill="none" xmlns="http://www.w3.org/2000/svg">
34
- <!-- CT scan ring -->
35
- <ellipse cx="80" cy="55" rx="72" ry="48" stroke="#243356" stroke-width="1.5"/>
36
- <ellipse cx="80" cy="55" rx="56" ry="36" stroke="#1e3060" stroke-width="1"/>
37
- <!-- gantry frame -->
38
- <rect x="8" y="18" width="12" height="74" rx="4" fill="#162244" stroke="#243356" stroke-width="1"/>
39
- <rect x="140" y="18" width="12" height="74" rx="4" fill="#162244" stroke="#243356" stroke-width="1"/>
40
- <!-- scan table -->
41
- <rect x="28" y="50" width="104" height="10" rx="4" fill="#111c33" stroke="#243356" stroke-width="1"/>
42
- <!-- question mark inside -->
43
- <text x="80" y="63" text-anchor="middle" font-family="Inter,sans-serif" font-size="28" font-weight="900"
44
- fill="url(#qgrad)" opacity=".9">?</text>
45
- <!-- sweep arc -->
46
- <path d="M80 7 A48 48 0 0 1 128 55" stroke="#6ea8fe" stroke-width="1.5" stroke-dasharray="6 4" opacity=".4"/>
47
- <defs>
48
- <linearGradient id="qgrad" x1="0" y1="0" x2="1" y2="1">
49
- <stop offset="0%" stop-color="#6ea8fe"/>
50
- <stop offset="100%" stop-color="#a78bfa"/>
51
- </linearGradient>
52
- </defs>
53
- </svg>
54
- </div>
55
-
56
  <h1 class="error-title">Page Not Found</h1>
57
  <p class="error-desc">
58
  We couldn't find the page you were looking for. It may have been moved, deleted,
 
28
  <div class="error-scanline"></div>
29
  </div>
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  <h1 class="error-title">Page Not Found</h1>
32
  <p class="error-desc">
33
  We couldn't find the page you were looking for. It may have been moved, deleted,