""" Celery task workers for async inference and batch processing. Handles long-running DICOM processing jobs with progress tracking via Redis. Run worker with: celery -A tasks worker --loglevel=info """ import logging import os import shutil import datetime import ssl import sys import traceback from pathlib import Path from typing import Any from zoneinfo import ZoneInfo # Ensure the app directory is in the Python path so imports work in worker processes APP_DIR = Path(__file__).parent.absolute() if str(APP_DIR) not in sys.path: sys.path.insert(0, str(APP_DIR)) try: from dotenv import load_dotenv load_dotenv() except ImportError: pass from celery import Celery, current_task logger = logging.getLogger(__name__) IST = ZoneInfo("Asia/Kolkata") def _now_ist() -> datetime.datetime: return datetime.datetime.now(IST).replace(tzinfo=None) def _env_int(name: str, default: int | None = None, *, minimum: int | None = None) -> int | None: raw = os.environ.get(name) if raw is None: return default try: value = int(raw) if minimum is not None and value < minimum: return default return value except ValueError: return default # Extract Redis URL from environment REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0") # Initialize Celery app celery_app = Celery( "ich_tasks", broker=REDIS_URL, backend=REDIS_URL, ) # Configure Celery with SSL support for Upstash Redis ssl_config = None redis_backend_ssl = None if REDIS_URL.startswith("rediss://"): ssl_config = {"ssl_cert_reqs": ssl.CERT_NONE} redis_backend_ssl = {"ssl_cert_reqs": ssl.CERT_NONE} celery_app.conf.update( broker_use_ssl=ssl_config, redis_backend_use_ssl=redis_backend_ssl, task_serializer="json", accept_content=["json"], result_serializer="json", timezone="Asia/Kolkata", enable_utc=False, task_track_started=True, task_time_limit=3600, # 1 hour hard limit task_soft_time_limit=3300, # 55 min soft limit result_expires=86400, # 24 hours ) extra_conf: dict[str, Any] = {} worker_concurrency = _env_int("ICH_CELERY_CONCURRENCY", None, minimum=1) worker_prefetch = _env_int("ICH_CELERY_PREFETCH_MULTIPLIER", None, minimum=1) if worker_concurrency is not None: extra_conf["worker_concurrency"] = worker_concurrency if worker_prefetch is not None: extra_conf["worker_prefetch_multiplier"] = worker_prefetch if extra_conf: celery_app.conf.update(**extra_conf) def _iter_batches(items: list[str], batch_size: int) -> list[list[str]]: return [items[i:i + batch_size] for i in range(0, len(items), batch_size)] @celery_app.task(bind=True, name="tasks.process_dicom_batch") def process_dicom_batch( self, batch_id: str, dcm_paths: list[str], user_id: int, temp_dir: str | None = None, ) -> dict[str, Any]: """ Process a batch of DICOM files asynchronously with progress tracking. Args: batch_id: Unique identifier for this batch job dcm_paths: List of DICOM file paths to process user_id: User ID for audit and data isolation temp_dir: Optional temporary directory to clean up after Returns: Dictionary with final batch status and results matching frontend expectations """ # Import here to avoid circular imports. Add diagnostics to help debug # ModuleNotFoundError issues when Celery workers can't find `app_new`. try: # Ensure APP_DIR is present in sys.path for worker subprocesses if str(APP_DIR) not in sys.path: sys.path.insert(0, str(APP_DIR)) logger.info(f"Inserted APP_DIR into sys.path: {APP_DIR}") else: logger.info(f"APP_DIR already in sys.path: {APP_DIR}") logger.info(f"tasks.py APP_DIR={APP_DIR}") logger.info(f"sys.path (first 10): {sys.path[:10]}") # List files in the app dir for visibility try: files = [p.name for p in Path(APP_DIR).iterdir() if p.exists()] logger.info(f"APP_DIR contents: {files[:50]}") except Exception as _e: logger.warning(f"Could not list APP_DIR contents: {_e}") from app_new import app, _run_inference_on_dcm from auth_utils import log_audit from models import ScreeningUpload, db except Exception as e: logger.error("Failed importing application modules inside Celery worker:\n" + traceback.format_exc()) raise total = len(dcm_paths) succeeded_ids = [] failed_ids = [] started_at = _now_ist().isoformat() logger.info(f"Batch {batch_id} starting: {total} files for user {user_id}") try: with app.app_context(): use_gpu_batch = False batch_size = 1 _infer_images_batch = None _persist_inference_result = None try: from app_new import ( GPU_BATCH_SIZE, _gpu_batch_ready, _infer_images_batch, _persist_inference_result, ) use_gpu_batch = _gpu_batch_ready() and total > 1 batch_size = max(1, GPU_BATCH_SIZE) except Exception: use_gpu_batch = False if use_gpu_batch and _infer_images_batch and _persist_inference_result: logger.info( "GPU batch inference enabled (size=%s); per-image traces are skipped.", batch_size, ) processed = 0 revoked = False for chunk in _iter_batches(dcm_paths, batch_size): if revoked: break paths = [Path(p) for p in chunk] upload_records: list[ScreeningUpload] = [] for path in paths: request_ctx = current_task.request is_revoked = bool(getattr(request_ctx, "is_revoked", False)) or bool( getattr(request_ctx, "revoked", False) ) if is_revoked: logger.info(f"Batch {batch_id} revoked, stopping") revoked = True break upload_record = ScreeningUpload( user_id=user_id, file_name=path.name, original_filename=path.name, file_size=path.stat().st_size if path.exists() else None, file_path=str(path), processing_status="processing", ) db.session.add(upload_record) db.session.commit() upload_records.append(upload_record) if revoked: break try: batch_results = _infer_images_batch(paths) except Exception as exc: logger.error( f"Batch {batch_id}: GPU batch inference failed — {exc}", exc_info=True, ) for path, upload_record in zip(paths, upload_records, strict=False): image_id = path.stem db.session.rollback() upload_record.processing_status = "failed" try: db.session.commit() except Exception: db.session.rollback() failed_ids.append(image_id) processed += 1 self.update_state( state="PROGRESS", meta={ "batch_id": batch_id, "user_id": user_id, "status": "running", "total": total, "processed": processed, "succeeded": len(succeeded_ids), "failed_ids": list(failed_ids), "image_ids": list(succeeded_ids), "current_file": "", "started_at": started_at, "finished_at": None, "error": None, "temp_dir": temp_dir, }, ) continue for (path, upload_record), (img_rgb, inference) in zip( zip(paths, upload_records, strict=False), batch_results, strict=False, ): image_id = path.stem self.update_state( state="PROGRESS", meta={ "batch_id": batch_id, "user_id": user_id, "status": "running", "total": total, "processed": processed, "succeeded": len(succeeded_ids), "failed_ids": list(failed_ids), "image_ids": list(succeeded_ids), "current_file": image_id, "started_at": started_at, "finished_at": None, "error": None, "temp_dir": temp_dir, }, ) try: report = _persist_inference_result( image_id, user_id, upload_record.id, img_rgb, inference, ) if report: upload_record.processing_status = "completed" db.session.commit() succeeded_ids.append(image_id) else: upload_record.processing_status = "failed" db.session.commit() failed_ids.append(image_id) except Exception as exc: logger.error(f"Batch {batch_id}: failed {image_id} — {exc}") db.session.rollback() upload_record.processing_status = "failed" try: db.session.commit() except Exception: db.session.rollback() failed_ids.append(image_id) processed += 1 self.update_state( state="PROGRESS", meta={ "batch_id": batch_id, "user_id": user_id, "status": "running", "total": total, "processed": processed, "succeeded": len(succeeded_ids), "failed_ids": list(failed_ids), "image_ids": list(succeeded_ids), "current_file": "", "started_at": started_at, "finished_at": None, "error": None, "temp_dir": temp_dir, }, ) else: for i, path_str in enumerate(dcm_paths, 1): # Check if task was revoked (compat across Celery versions) request_ctx = current_task.request is_revoked = bool(getattr(request_ctx, "is_revoked", False)) or bool( getattr(request_ctx, "revoked", False) ) if is_revoked: logger.info(f"Batch {batch_id} revoked, stopping") break path = Path(path_str) image_id = path.stem upload_record = ScreeningUpload( user_id=user_id, file_name=path.name, original_filename=path.name, file_size=path.stat().st_size if path.exists() else None, file_path=str(path), processing_status="processing", ) db.session.add(upload_record) db.session.commit() # Update Celery task state with progress (matches _BATCHES format for frontend) self.update_state( state="PROGRESS", meta={ "batch_id": batch_id, "user_id": user_id, "status": "running", "total": total, "processed": i - 1, "succeeded": len(succeeded_ids), "failed_ids": list(failed_ids), "image_ids": list(succeeded_ids), "current_file": image_id, "started_at": started_at, "finished_at": None, "error": None, "temp_dir": temp_dir, }, ) try: report, _ = _run_inference_on_dcm(path, user_id, upload_record.id) if report: upload_record.processing_status = "completed" db.session.commit() succeeded_ids.append(image_id) else: upload_record.processing_status = "failed" db.session.commit() failed_ids.append(image_id) except Exception as e: logger.error(f"Batch {batch_id}: failed {image_id} — {e}") db.session.rollback() upload_record.processing_status = "failed" try: db.session.commit() except Exception: db.session.rollback() failed_ids.append(image_id) # Update after processing each file self.update_state( state="PROGRESS", meta={ "batch_id": batch_id, "user_id": user_id, "status": "running", "total": total, "processed": i, "succeeded": len(succeeded_ids), "failed_ids": list(failed_ids), "image_ids": list(succeeded_ids), "current_file": "", "started_at": started_at, "finished_at": None, "error": None, "temp_dir": temp_dir, }, ) # Cleanup temporary directory if provided if temp_dir and Path(temp_dir).exists(): try: shutil.rmtree(temp_dir, ignore_errors=True) logger.info(f"Cleaned up temp_dir: {temp_dir}") except Exception as e: logger.warning(f"Failed to clean temp_dir {temp_dir}: {e}") # Log final audit result with app.app_context(): audit_status = "success" if len(failed_ids) == 0 else "partial" log_audit( "batch_processing_completed", user_id=user_id, details=f"batch_id={batch_id}, processed={total}, succeeded={len(succeeded_ids)}, failed={len(failed_ids)}", status=audit_status, ) # Return final result matching _BATCHES format for frontend compatibility result = { "batch_id": batch_id, "user_id": user_id, "status": "completed", "total": total, "processed": total, "succeeded": len(succeeded_ids), "failed_ids": list(failed_ids), "image_ids": list(succeeded_ids), "current_file": "", "started_at": started_at, "finished_at": _now_ist().isoformat(), "error": None, "temp_dir": temp_dir, } logger.info( f"Batch {batch_id} complete: {len(succeeded_ids)}/{total} succeeded, " f"{len(failed_ids)} failed" ) return result except Exception as e: logger.error(f"Batch {batch_id} error: {e}", exc_info=True) with app.app_context(): log_audit( "batch_processing_failed", user_id=user_id, details=f"batch_id={batch_id}, error={str(e)}", status="failure", ) raise @celery_app.task(name="tasks.health_check") def health_check() -> str: """Simple health check task for monitoring.""" return "Celery worker is healthy"