File size: 11,652 Bytes
674fb4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
"""
Celery workers for async document ingestion
Decouples ingestion from the API request loop

KNOWN LIMITATION: Currently creates a new Neo4jStore per task.
At scale, this requires dedicated Neo4j read replicas or connection pooling
like PgBouncer (Neo4j driver handles some internal pooling, but high
concurrency can exhaust connections).
"""

from celery import Celery
from celery.schedules import crontab
from pathlib import Path
import asyncio

from ..config import settings
from ..ingestion.pipeline import IngestionPipeline
from ..core.storage import get_storage
import tempfile
import io
from ..core.neo4j_store import Neo4jStore
from ..core.llm_factory import UnifiedLLMProvider
from ..ingestion.persona_generator import PersonaGenerator
from .simulation_runner import SimulationManager

# Celery application shared by every task in this module; broker and result
# backend URLs come from the application settings.
celery_app = Celery(
    main='graph_rag_workers',
    broker=settings.celery_broker_url,
    backend=settings.celery_result_backend,
)

# Serialization, clock and time-limit settings, plus the Celery Beat schedule
# for the periodic maintenance tasks registered further down in this module.
celery_app.conf.update(
    # JSON is the only accepted wire format for arguments and results.
    task_serializer='json',
    result_serializer='json',
    accept_content=['json'],
    timezone='UTC',
    enable_utc=True,
    # Have workers report a STARTED state while a task is running.
    task_track_started=True,
    task_soft_time_limit=3000,  # soft limit: 50 minutes
    task_time_limit=3600,  # hard limit: 1 hour
    beat_schedule={
        # 02:00 daily
        'cleanup-orphan-nodes-daily': {
            'task': 'cleanup_orphan_nodes',
            'schedule': crontab(hour=2, minute=0),
        },
        # 02:30 daily, i.e. after the orphan cleanup above
        'enrich-entities-daily': {
            'task': 'enrich_entities',
            'schedule': crontab(hour=2, minute=30),
        },
        # 03:00 daily
        'ontology-drift-check-daily': {
            'task': 'check_ontology_drift',
            'schedule': crontab(hour=3, minute=0),
        },
    },
)

from celery.signals import worker_process_init

# Event loop owned by the current worker process. It stays None until the
# worker_process_init handler below runs (e.g. it is None in tests).
_worker_loop = None


@worker_process_init.connect
def _init_worker_loop(**kwargs):
    """Create and install a dedicated asyncio event loop for this process.

    Connected to Celery's ``worker_process_init`` signal. The loop is stored in
    the module-level ``_worker_loop`` so ``run_async`` can reuse it for every
    task executed by this process.
    """
    global _worker_loop
    fresh_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    _worker_loop = fresh_loop

def run_async(coro):
    """Run *coro* to completion from synchronous Celery task code.

    When the per-process loop created at ``worker_process_init`` exists, it is
    reused, so a single loop serves every task in the process. Otherwise, for
    example in tests or outside a Celery worker process, the coroutine runs on
    a throwaway loop via ``asyncio.run``.

    Returns:
        Whatever the coroutine returns; its exceptions propagate unchanged.
    """
    loop = _worker_loop
    if loop is None:
        # Fallback if not running in a Celery worker process (e.g. tests)
        return asyncio.run(coro)
    return loop.run_until_complete(coro)


@celery_app.task(name='ingest_document', bind=True)
def ingest_document_task(self, file_path: str, ontology_dict: dict = None, tenant_id: str = None):
    """
    Celery task for document ingestion

    Reads the document through the configured storage backend, copies it into
    a private temporary directory and runs the ingestion pipeline on that copy.
    Chunk-level progress is published as a 'PROCESSING' task state.

    Args:
        file_path: Path/key of the document, as accepted by
            ``get_storage().read_file``
        ontology_dict: Optional ontology as dictionary (OntologySchema fields)
        tenant_id: Tenant to scope ingestion to

    Returns:
        Extraction result as dictionary (entity/relationship/chunk counts,
        ontology version, processing time), or
        ``{'status': 'error', 'error': '<ExcType>: <message>'}`` on failure.
        The task never raises.
    """

    async def _ingest():
        # Initialize pipeline
        graph_store = Neo4jStore()
        pipeline = IngestionPipeline(graph_store=graph_store)

        def progress_cb(current, total):
            self.update_state(
                state='PROCESSING',
                meta={'file': file_path, 'current_chunk': current, 'total_chunks': total}
            )

        try:
            await pipeline.initialize()

            # Convert ontology dict if provided
            ontology = None
            if ontology_dict:
                from ..core.models import OntologySchema
                ontology = OntologySchema(**ontology_dict)

            # Fetch the document from storage
            storage = get_storage()
            file_bytes = storage.read_file(file_path)

            with tempfile.TemporaryDirectory() as temp_dir:
                # Use only the base name for the local copy. Joining the raw
                # file_path would discard temp_dir for absolute paths (so the
                # copy would overwrite the original location), would let '..'
                # segments escape the temp dir, and would fail for nested keys
                # whose parent folders do not exist inside temp_dir. The base
                # name, including its extension, is kept unchanged.
                local_name = Path(file_path).name
                if local_name in ("", ".."):
                    local_name = "document"
                temp_path = Path(temp_dir) / local_name
                temp_path.write_bytes(file_bytes)

                result = await pipeline.ingest_document(
                    temp_path,
                    ontology=ontology,
                    progress_callback=progress_cb,
                    tenant_id=tenant_id
                )

            # Convert result to dict
            return {
                "entities_count": len(result.entities),
                "relationships_count": len(result.relationships),
                "chunks_count": len(result.chunks),
                "ontology_version": result.ontology_version,
                "processing_time_seconds": result.processing_time_seconds
            }
        finally:
            await pipeline.close()

    # Update task state
    self.update_state(state='PROCESSING', meta={'file': file_path})

    try:
        result = run_async(_ingest())
        return result
    except Exception as e:
        # Return error as a plain dict — never raise.
        # Raising any exception (even builtins) can crash the Celery worker
        # when the Redis backend holds a previously corrupt task result.
        error_msg = f"{type(e).__name__}: {e}"
        return {'status': 'error', 'error': error_msg}


@celery_app.task(name='ingest_documents_batch', bind=True)
def ingest_documents_batch_task(self, file_paths: list, ontology_dict: dict = None, tenant_id: str = None):
    """
    Celery task for batch document ingestion

    Args:
        file_paths: List of document file paths; each is wrapped in
            ``pathlib.Path`` and passed to ``IngestionPipeline.ingest_documents``
        ontology_dict: Optional ontology as dictionary
        tenant_id: Tenant to scope ingestion to. Forwarded to the pipeline
            only when not None, so untenanted batches make exactly the same
            pipeline call as before.

    Returns:
        List of extraction results (one dict per pipeline result), or
        ``{'status': 'error', 'error': '<ExcType>: <message>'}`` on failure.
        The task never raises.
    """
    # NOTE(review): unlike ingest_document_task, files are not fetched through
    # get_storage(); this assumes file_paths are readable on the worker's local
    # filesystem. Confirm against the caller.

    async def _ingest_batch():
        graph_store = Neo4jStore()
        pipeline = IngestionPipeline(graph_store=graph_store)

        try:
            await pipeline.initialize()

            ontology = None
            if ontology_dict:
                from ..core.models import OntologySchema
                ontology = OntologySchema(**ontology_dict)

            # Forward the tenant so batch ingests are tenant-scoped like
            # single-document ingests instead of silently landing unscoped.
            # Assumes ingest_documents accepts tenant_id the way
            # ingest_document does (TODO confirm). If it does not, a tenanted
            # batch fails loudly (error dict) instead of ignoring the tenant.
            tenant_kwargs = {}
            if tenant_id is not None:
                tenant_kwargs['tenant_id'] = tenant_id

            results = await pipeline.ingest_documents(
                [Path(fp) for fp in file_paths],
                ontology=ontology,
                **tenant_kwargs,
            )

            return [
                {
                    "entities_count": len(r.entities),
                    "relationships_count": len(r.relationships),
                    "chunks_count": len(r.chunks),
                    "ontology_version": r.ontology_version,
                    "processing_time_seconds": r.processing_time_seconds
                }
                for r in results
            ]
        finally:
            await pipeline.close()

    self.update_state(state='PROCESSING', meta={'files_count': len(file_paths)})

    try:
        results = run_async(_ingest_batch())
        return results
    except Exception as e:
        # Same policy as ingest_document_task: report, never raise.
        error_msg = f"{type(e).__name__}: {e}"
        return {'status': 'error', 'error': error_msg}


@celery_app.task(name='cleanup_orphan_nodes')
def cleanup_orphan_nodes_task():
    """
    Background job to clean up disconnected or orphaned nodes in Neo4j.
    Scheduled via Celery Beat.

    Deletes, in order:
      1. Entity nodes that have no relationships at all.
      2. Chunk nodes that are neither contained in a Document nor mention an
         Entity. Any other relationships such a chunk still has are removed
         together with it.

    Returns:
        ``{'status': 'success', 'deleted_entities': int, 'deleted_chunks': int}``,
        or ``{'status': 'error', 'error': '<ExcType>: <message>'}`` on failure.
        Like the ingestion, enrichment and drift tasks, it does not raise.
    """
    async def _clean():
        graph_store = Neo4jStore()
        await graph_store.connect()
        try:
            # Delete Entity nodes with 0 relationships. The pattern predicate
            # `NOT (n)--()` is used instead of `size((n)--()) = 0`, because
            # size() over a pattern expression is not supported in Neo4j 5.
            query = """
            MATCH (n:Entity)
            WHERE NOT (n)--()
            DELETE n
            RETURN count(n) as deleted_count
            """
            result = await graph_store.execute_query(query)

            # Delete unlinked Chunks. DETACH DELETE is required because such a
            # chunk may still hold other relationship types, and a plain
            # DELETE would make the whole statement fail.
            chunk_query = """
            MATCH (c:Chunk)
            WHERE NOT (c)<-[:CONTAINS]-(:Document) AND NOT (c)-[:MENTIONS]->(:Entity)
            DETACH DELETE c
            RETURN count(c) as deleted_chunks
            """
            chunk_res = await graph_store.execute_query(chunk_query)

            return {
                "status": "success",
                "deleted_entities": result[0]["deleted_count"] if result else 0,
                "deleted_chunks": chunk_res[0]["deleted_chunks"] if chunk_res else 0
            }
        finally:
            await graph_store.disconnect()

    try:
        return run_async(_clean())
    except Exception as e:
        return {'status': 'error', 'error': f"{type(e).__name__}: {e}"}


@celery_app.task(name='health_check')
def health_check():
    """Report that a worker picked up and executed this task.

    The payload is constant; no dependencies (such as Neo4j) are probed.
    """
    payload = {"status": "ok", "message": "Worker is healthy"}
    return payload


@celery_app.task(name='generate_personas')
def generate_personas_task(entity_type='Person'):
    '''Celery task to run the Ontology-to-Persona Pipeline asynchronously.

    Args:
        entity_type: Entity type passed to
            ``PersonaGenerator.generate_personas_for_type`` (default 'Person').

    Returns:
        ``{'status': 'success', 'personas_generated': <count>}``.
        Exceptions propagate, so the Celery task is marked as failed.
    '''
    async def async_run():
        store = Neo4jStore()
        await store.connect()
        try:
            llm = UnifiedLLMProvider()
            generator = PersonaGenerator(store, llm)
            count = await generator.generate_personas_for_type(entity_type)
        finally:
            # Release the Neo4j connection even when generation fails.
            await store.disconnect()
        return {'status': 'success', 'personas_generated': count}
    return run_async(async_run())

@celery_app.task(name='run_simulation_tick')
def run_simulation_tick_task():
    '''Celery task to run a Multi-Agent Sandbox Simulation Tick (Point 4).

    Returns:
        ``{'status': 'success', 'actions_taken': <value returned by
        SimulationManager.run_simulation_tick>}``.
        Exceptions propagate, so the Celery task is marked as failed.
    '''
    async def async_run():
        store = Neo4jStore()
        await store.connect()
        try:
            llm = UnifiedLLMProvider()
            manager = SimulationManager(store, llm)
            actions_taken = await manager.run_simulation_tick()
        finally:
            # Release the Neo4j connection even when the tick fails.
            await store.disconnect()
        return {'status': 'success', 'actions_taken': actions_taken}
    return run_async(async_run())


@celery_app.task(name='enrich_entities', bind=True)
def enrich_entities_task(self, min_connections: int = 1, overwrite: bool = False):
    """
    Run the Entity Enricher over the graph.

    The enricher generates LLM profile summaries for well-connected entities
    and persists them to Neo4j. The task is on the daily Celery Beat schedule
    ('enrich-entities-daily'). NOTE(review): the original docstring also says
    it is triggered after ingestion; no such trigger is visible in this module.

    Args:
        min_connections: Forwarded to ``EntityEnricher.enrich_all_entities``.
        overwrite: Forwarded to ``EntityEnricher.enrich_all_entities``.

    Returns:
        ``{'status': 'success', ...}`` with the enricher's counters and
        duration, or ``{'status': 'error', 'error': '<ExcType>: <message>'}``.
        The task never raises.
    """
    async def _enrich():
        from ..services.entity_enricher import EntityEnricher
        store = Neo4jStore()
        await store.connect()
        try:
            summary = await EntityEnricher(graph_store=store).enrich_all_entities(
                min_connections=min_connections,
                overwrite=overwrite,
            )
            return dict(
                status='success',
                entities_enriched=summary.entities_enriched,
                entities_skipped=summary.entities_skipped,
                errors=summary.errors,
                duration_seconds=summary.duration_seconds,
            )
        finally:
            await store.disconnect()

    try:
        outcome = run_async(_enrich())
    except Exception as exc:
        outcome = {'status': 'error', 'error': f"{type(exc).__name__}: {exc}"}
    return outcome


@celery_app.task(name='check_ontology_drift', bind=True)
def check_ontology_drift_task(self, sample_size: int = 10):
    """
    Check the graph for ontology drift.

    Delegates to ``OntologyDriftDetector.detect_drift``. According to the
    original description, that call re-samples random chunks, proposes a new
    ontology, diffs it against the current schema and creates a pending
    DriftReport node for admin review.

    Args:
        sample_size: Forwarded to ``detect_drift``.

    Returns:
        A success dict with the report id, drift score and newly proposed
        types; a 'no_ontology' dict when the detector returns no report; or
        ``{'status': 'error', 'error': '<ExcType>: <message>'}``.
        The task never raises.
    """
    async def _detect():
        from ..services.ontology_drift_detector import OntologyDriftDetector
        store = Neo4jStore()
        await store.connect()
        try:
            report = await OntologyDriftDetector(graph_store=store).detect_drift(
                sample_size=sample_size
            )
            if not report:
                return {'status': 'no_ontology', 'message': 'No ontology found — nothing to diff against'}
            return dict(
                status='success',
                report_id=report.id,
                drift_score=report.drift_score,
                new_entity_types=report.new_entity_types,
                new_relationship_types=report.new_relationship_types,
            )
        finally:
            await store.disconnect()

    try:
        outcome = run_async(_detect())
    except Exception as exc:
        outcome = {'status': 'error', 'error': f"{type(exc).__name__}: {exc}"}
    return outcome