| """ |
| Celery task workers for async inference and batch processing. |
| Handles long-running DICOM processing jobs with progress tracking via Redis. |
| |
| Run worker with: celery -A tasks worker --loglevel=info |
| """ |
|
|
| import logging |
| import os |
| import shutil |
| import datetime |
| import ssl |
| import sys |
| import traceback |
| from pathlib import Path |
| from typing import Any |
| from zoneinfo import ZoneInfo |
|
|
| |
# Directory that holds this module. It is pushed to the front of sys.path so
# the task body can import app_new, auth_utils and models by bare name no
# matter which working directory the worker process was started from.
APP_DIR = Path(__file__).parent.absolute()
if str(APP_DIR) not in sys.path:
    sys.path.insert(0, str(APP_DIR))

# python-dotenv is optional: when it is installed, load_dotenv() is called
# before any os.environ lookups below; when it is missing the module carries
# on with whatever the process environment already provides.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass

# Imported only after the sys.path / dotenv bootstrap above has run.
from celery import Celery, current_task

logger = logging.getLogger(__name__)

# Zone used by _now_ist(); the Celery `timezone` setting names the same zone.
IST = ZoneInfo("Asia/Kolkata")
|
|
def _now_ist() -> datetime.datetime:
    """Return the current Asia/Kolkata wall-clock time as a naive datetime."""
    aware_now = datetime.datetime.now(tz=IST)
    # The tzinfo is stripped on purpose: callers receive a naive value whose
    # fields already hold IST wall-clock time.
    return aware_now.replace(tzinfo=None)
|
|
| def _env_int(name: str, default: int | None = None, *, minimum: int | None = None) -> int | None: |
| raw = os.environ.get(name) |
| if raw is None: |
| return default |
| try: |
| value = int(raw) |
| if minimum is not None and value < minimum: |
| return default |
| return value |
| except ValueError: |
| return default |
|
|
| |
# Redis location; a local instance is assumed when REDIS_URL is not set.
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")

# The same Redis URL serves as both the message broker and the result backend.
celery_app = Celery("ich_tasks", broker=REDIS_URL, backend=REDIS_URL)

# TLS options are supplied only for rediss:// URLs; for any other scheme both
# settings are passed to Celery as None.
# NOTE(review): ssl.CERT_NONE disables certificate verification on the Redis
# connection -- confirm this is intentional for the deployment target.
ssl_config = (
    {"ssl_cert_reqs": ssl.CERT_NONE} if REDIS_URL.startswith("rediss://") else None
)
redis_backend_ssl = (
    {"ssl_cert_reqs": ssl.CERT_NONE} if REDIS_URL.startswith("rediss://") else None
)

celery_app.conf.update(
    # Transport security (see above).
    broker_use_ssl=ssl_config,
    redis_backend_use_ssl=redis_backend_ssl,
    # JSON only, for task payloads and results alike.
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    # Celery clock: local Asia/Kolkata time rather than UTC.
    timezone="Asia/Kolkata",
    enable_utc=False,
    task_track_started=True,
    # Time limits in seconds: the soft limit (55 min) fires before the hard
    # limit (60 min).
    task_time_limit=3600,
    task_soft_time_limit=3300,
    # Stored results expire after 24 hours.
    result_expires=86400,
)

# Optional worker tuning. A variable that is unset, unparsable or below 1
# yields None and leaves the corresponding Celery default untouched.
worker_concurrency = _env_int("ICH_CELERY_CONCURRENCY", None, minimum=1)
worker_prefetch = _env_int("ICH_CELERY_PREFETCH_MULTIPLIER", None, minimum=1)
extra_conf: dict[str, Any] = {
    setting: value
    for setting, value in (
        ("worker_concurrency", worker_concurrency),
        ("worker_prefetch_multiplier", worker_prefetch),
    )
    if value is not None
}
if extra_conf:
    celery_app.conf.update(**extra_conf)
|
|
def _iter_batches(items: list[str], batch_size: int) -> list[list[str]]:
    """Split ``items`` into consecutive chunks of at most ``batch_size`` entries.

    Order is preserved and only the last chunk may be shorter. An empty
    ``items`` list produces an empty result.

    Args:
        items: Values to partition.
        batch_size: Maximum number of entries per chunk; must be at least 1.

    Returns:
        A list of chunks, each itself a list.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. Without this check a
            negative size would silently yield no chunks at all (dropping
            every item) and zero would fail with an opaque ``range()`` error.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    return [items[start:start + batch_size] for start in range(0, len(items), batch_size)]
|
|
|
|
def _task_revoked() -> bool:
    """Return True when the running Celery request carries a truthy revocation flag."""
    request_ctx = current_task.request
    # Neither attribute is guaranteed to exist on the request object, hence
    # the getattr defaults.
    return bool(getattr(request_ctx, "is_revoked", False)) or bool(
        getattr(request_ctx, "revoked", False)
    )


def _remove_temp_dir(temp_dir: str | None) -> None:
    """Best-effort removal of a batch's temporary directory; never raises."""
    try:
        if not temp_dir or not Path(temp_dir).exists():
            return
        shutil.rmtree(temp_dir, ignore_errors=True)
        logger.info(f"Cleaned up temp_dir: {temp_dir}")
    except Exception as e:
        logger.warning(f"Failed to clean temp_dir {temp_dir}: {e}")


@celery_app.task(bind=True, name="tasks.process_dicom_batch")
def process_dicom_batch(
    self,
    batch_id: str,
    dcm_paths: list[str],
    user_id: int,
    temp_dir: str | None = None,
) -> dict[str, Any]:
    """
    Process a batch of DICOM files asynchronously with progress tracking.

    A ``ScreeningUpload`` row is created for every file and moved from
    "processing" to "completed" or "failed". After each step the task
    publishes a ``PROGRESS`` state whose meta has the same keys as the final
    result. When the application exposes its GPU batch helpers and more than
    one file was submitted, files are inferred in chunks; otherwise they are
    processed one at a time. A failure on one file is recorded and does not
    stop the batch.

    Args:
        batch_id: Unique identifier for this batch job
        dcm_paths: List of DICOM file paths to process
        user_id: User ID for audit and data isolation
        temp_dir: Optional temporary directory to clean up after. It is
            removed once processing ends, whether it succeeded or failed.

    Returns:
        Dictionary with final batch status and results matching frontend
        expectations. ``processed`` is the number of files actually handled,
        which is lower than ``total`` only when the task stopped early
        because it was revoked.

    Raises:
        Exception: Import errors for the application modules, and any error
            outside the per-file handling, are logged (and audited where
            possible) and then re-raised unchanged.
    """
    try:
        # Diagnostics for import problems inside the worker process.
        if str(APP_DIR) not in sys.path:
            sys.path.insert(0, str(APP_DIR))
            logger.info(f"Inserted APP_DIR into sys.path: {APP_DIR}")
        else:
            logger.info(f"APP_DIR already in sys.path: {APP_DIR}")

        logger.info(f"tasks.py APP_DIR={APP_DIR}")
        logger.info(f"sys.path (first 10): {sys.path[:10]}")

        try:
            files = [p.name for p in Path(APP_DIR).iterdir() if p.exists()]
            logger.info(f"APP_DIR contents: {files[:50]}")
        except Exception as _e:
            logger.warning(f"Could not list APP_DIR contents: {_e}")

        # Imported here rather than at module import time.
        from app_new import app, _run_inference_on_dcm
        from auth_utils import log_audit
        from models import ScreeningUpload, db
    except Exception:
        logger.error("Failed importing application modules inside Celery worker:\n" + traceback.format_exc())
        raise

    total = len(dcm_paths)
    succeeded_ids: list[str] = []
    failed_ids: list[str] = []
    started_at = _now_ist().isoformat()

    def snapshot(
        *,
        status: str = "running",
        current_file: str = "",
        finished_at: str | None = None,
    ) -> dict[str, Any]:
        """Build the status payload shared by progress updates and the final result."""
        return {
            "batch_id": batch_id,
            "user_id": user_id,
            "status": status,
            "total": total,
            # Every handled file lands in exactly one of the two lists, so
            # their combined length is the processed count.
            "processed": len(succeeded_ids) + len(failed_ids),
            "succeeded": len(succeeded_ids),
            "failed_ids": list(failed_ids),
            "image_ids": list(succeeded_ids),
            "current_file": current_file,
            "started_at": started_at,
            "finished_at": finished_at,
            "error": None,
            "temp_dir": temp_dir,
        }

    def report_progress(current_file: str = "") -> None:
        """Publish the current counters as a PROGRESS state."""
        self.update_state(state="PROGRESS", meta=snapshot(current_file=current_file))

    def new_upload_record(path: Path) -> Any:
        """Insert and commit a "processing" upload row for ``path``."""
        record = ScreeningUpload(
            user_id=user_id,
            file_name=path.name,
            original_filename=path.name,
            file_size=path.stat().st_size if path.exists() else None,
            file_path=str(path),
            processing_status="processing",
        )
        db.session.add(record)
        db.session.commit()
        return record

    def record_outcome(record: Any, image_id: str, ok: bool) -> None:
        """Persist the final row status, then count the file."""
        record.processing_status = "completed" if ok else "failed"
        db.session.commit()
        (succeeded_ids if ok else failed_ids).append(image_id)

    def record_error(record: Any, image_id: str) -> None:
        """Discard the broken transaction and flag the row as failed."""
        db.session.rollback()
        record.processing_status = "failed"
        try:
            db.session.commit()
        except Exception:
            logger.warning(
                f"Batch {batch_id}: could not persist failed status for {image_id}",
                exc_info=True,
            )
            db.session.rollback()
        failed_ids.append(image_id)

    logger.info(f"Batch {batch_id} starting: {total} files for user {user_id}")

    try:
        with app.app_context():
            use_gpu_batch = False
            batch_size = 1
            _infer_images_batch = None
            _persist_inference_result = None
            try:
                from app_new import (
                    GPU_BATCH_SIZE,
                    _gpu_batch_ready,
                    _infer_images_batch,
                    _persist_inference_result,
                )
                use_gpu_batch = _gpu_batch_ready() and total > 1
                batch_size = max(1, GPU_BATCH_SIZE)
            except Exception as exc:
                # Deliberate fallback: without the batch helpers the files
                # are simply processed one at a time.
                use_gpu_batch = False
                logger.info(
                    "GPU batch inference unavailable (%s); using per-file inference.",
                    exc,
                )

            if use_gpu_batch and _infer_images_batch and _persist_inference_result:
                logger.info(
                    "GPU batch inference enabled (size=%s); per-image traces are skipped.",
                    batch_size,
                )
                for chunk in _iter_batches(dcm_paths, batch_size):
                    # Revocation is checked before any row of the chunk is
                    # created, so a stop never leaves rows stuck in
                    # "processing".
                    if _task_revoked():
                        logger.info(f"Batch {batch_id} revoked, stopping")
                        break

                    paths = [Path(p) for p in chunk]
                    upload_records = [new_upload_record(path) for path in paths]

                    try:
                        batch_results = _infer_images_batch(paths)
                    except Exception as exc:
                        logger.error(
                            f"Batch {batch_id}: GPU batch inference failed — {exc}",
                            exc_info=True,
                        )
                        # The whole chunk shares the failure.
                        for path, upload_record in zip(paths, upload_records, strict=True):
                            record_error(upload_record, path.stem)
                            report_progress()
                        continue

                    # Results are consumed one by one so that a short result
                    # list marks the remaining files as failed instead of
                    # silently dropping them.
                    results_iter = iter(batch_results)
                    for path, upload_record in zip(paths, upload_records, strict=True):
                        image_id = path.stem
                        report_progress(current_file=image_id)

                        outcome = next(results_iter, None)
                        if outcome is None:
                            logger.error(
                                f"Batch {batch_id}: no inference result returned for {image_id}"
                            )
                            record_error(upload_record, image_id)
                            report_progress()
                            continue

                        try:
                            img_rgb, inference = outcome
                            report = _persist_inference_result(
                                image_id,
                                user_id,
                                upload_record.id,
                                img_rgb,
                                inference,
                            )
                            record_outcome(upload_record, image_id, bool(report))
                        except Exception as exc:
                            logger.error(f"Batch {batch_id}: failed {image_id} — {exc}")
                            record_error(upload_record, image_id)

                        report_progress()
            else:
                for path_str in dcm_paths:
                    if _task_revoked():
                        logger.info(f"Batch {batch_id} revoked, stopping")
                        break

                    path = Path(path_str)
                    image_id = path.stem
                    upload_record = new_upload_record(path)

                    report_progress(current_file=image_id)

                    try:
                        report, _ = _run_inference_on_dcm(path, user_id, upload_record.id)
                        record_outcome(upload_record, image_id, bool(report))
                    except Exception as e:
                        logger.error(f"Batch {batch_id}: failed {image_id} — {e}")
                        record_error(upload_record, image_id)

                    report_progress()

        processed = len(succeeded_ids) + len(failed_ids)

        with app.app_context():
            # A batch that stopped early is not a full success even when
            # nothing failed.
            fully_succeeded = not failed_ids and processed == total
            audit_status = "success" if fully_succeeded else "partial"
            log_audit(
                "batch_processing_completed",
                user_id=user_id,
                details=f"batch_id={batch_id}, processed={processed}, succeeded={len(succeeded_ids)}, failed={len(failed_ids)}",
                status=audit_status,
            )

        # NOTE(review): the status stays "completed" even after an early stop,
        # because the consumer of this payload is not visible here and may
        # depend on that value.
        result = snapshot(status="completed", finished_at=_now_ist().isoformat())

        logger.info(
            f"Batch {batch_id} complete: {len(succeeded_ids)}/{total} succeeded, "
            f"{len(failed_ids)} failed"
        )
        return result

    except Exception as e:
        logger.error(f"Batch {batch_id} error: {e}", exc_info=True)
        try:
            with app.app_context():
                log_audit(
                    "batch_processing_failed",
                    user_id=user_id,
                    details=f"batch_id={batch_id}, error={str(e)}",
                    status="failure",
                )
        except Exception:
            # An audit failure must not replace the error that broke the batch.
            logger.warning(
                f"Batch {batch_id}: could not write failure audit entry",
                exc_info=True,
            )
        raise
    finally:
        # Runs on success and on failure so uploaded files are not left on
        # disk when the batch aborts.
        _remove_temp_dir(temp_dir)
|
|
|
|
@celery_app.task(name="tasks.health_check")
def health_check() -> str:
    """Return a fixed status string; a result proves a worker picked up the task."""
    status_message = "Celery worker is healthy"
    return status_message
|
|