File size: 5,340 Bytes
f7559d3
4a90885
 
 
 
 
 
eff14fc
 
 
4a90885
eff14fc
4a90885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eff14fc
 
 
 
 
 
 
 
4a90885
eff14fc
4a90885
eff14fc
 
 
 
 
 
4a90885
 
 
eff14fc
 
4a90885
eff14fc
 
 
4a90885
eff14fc
 
4a90885
eff14fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c9bf1e
 
eff14fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c9bf1e
 
4a90885
 
 
 
eff14fc
4a90885
eff14fc
 
 
4a90885
eff14fc
 
4a90885
eff14fc
4a90885
 
 
 
 
 
eff14fc
4a90885
eff14fc
 
 
4a90885
 
eff14fc
4a90885
 
eff14fc
4a90885
 
eff14fc
 
4a90885
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
use std::convert::Infallible;
use crate::shield;
use crate::web3;
use crate::proof;
use anyhow::Result;
use serde::Serialize;
use std::time::Instant;
use axum::response::sse::Event;
use tokio::sync::mpsc::Sender;
use serde_json::Value;

/// Timing and summary record for one stage of the triage pipeline.
///
/// `run_triage_stream` pushes one of these per stage ("Shield", "Triage",
/// "Audit") and returns them in `TriageOutput::agent_steps`.
#[derive(Debug, Serialize, Clone)]
pub struct AgentStep {
    /// Stage label, e.g. `"Shield"`, `"Triage"` or `"Audit"`.
    pub name: String,
    /// Stage outcome; the pipeline in this file only ever records `"completed"`.
    pub status: String,
    /// Wall-clock time attributed to the stage, in milliseconds.
    pub duration_ms: u64,
    /// Short human-readable note about what the stage did.
    pub reasoning: String,
}

/// Final, serializable result of a full triage run.
#[derive(Debug, Serialize)]
pub struct TriageOutput {
    /// Patient note as returned by `shield::redact::redact_pii`; this is the
    /// text sent to the model as the prompt.
    pub redacted_prompt: String,
    /// PII matches reported by `shield::redact::redact_pii`.
    pub pii_map: Vec<shield::redact::PiiMatch>,
    /// Concatenated model tokens, or the fixed mock message when the
    /// inference endpoint answered with a non-success status.
    pub triage_result: String,
    /// `"7B (vLLM)"` when the inference endpoint succeeded, `"mock"` otherwise.
    pub model_used: String,
    /// Value returned by `web3::filecoin::generate_cid` for the
    /// `redacted||result||proof` string.
    pub cid: String,
    /// Value returned by `web3::base_tx::commit_cid` for `cid`.
    pub transaction_hash: String,
    /// Value returned by `proof::generate_proof` for the original note and PII map.
    pub redaction_proof: String,
    /// `"ROCm/MI300X"` when the inference endpoint succeeded,
    /// `"CPU (fallback)"` otherwise.
    pub device_info: String,
    /// One entry per pipeline stage, in execution order.
    pub agent_steps: Vec<AgentStep>,
}

/// Runs the full triage pipeline without streaming progress to anyone.
///
/// Delegates to [`run_triage_stream`] with a channel whose receiving half is
/// already closed, so every progress event fails to send immediately and is
/// ignored by the pipeline.
///
/// # Errors
///
/// Returns whatever error [`run_triage_stream`] returns.
pub async fn run_triage(patient_note: &str) -> Result<TriageOutput> {
    let (sink, receiver) = tokio::sync::mpsc::channel(10);
    // The receiver must not outlive this line: a live but unread receiver
    // would let the 10-slot buffer fill up and block the pipeline forever.
    drop(receiver);
    let owned_note = patient_note.to_owned();
    run_triage_stream(owned_note, sink).await
}

pub async fn run_triage_stream(
    patient_note: String,
    tx: Sender<Result<Event, Infallible>>,
) -> Result<TriageOutput> {
    let mut steps = Vec::new();
    let start = Instant::now();

    // Shield
    let _ = tx.send(Ok(Event::default().data(
        serde_json::json!({"agent":"Shield","status":"started"}).to_string()
    ))).await;
    let (redacted, pii_map) = shield::redact::redact_pii(&patient_note);
    let proof_sig = proof::generate_proof(&patient_note, &pii_map);
    steps.push(AgentStep {
        name: "Shield".into(),
        status: "completed".into(),
        duration_ms: start.elapsed().as_millis() as u64,
        reasoning: format!("Detected {} PII entities", pii_map.len()),
    });
    let _ = tx.send(Ok(Event::default().data(
        serde_json::json!({"agent":"Shield","status":"completed","pii_map":pii_map}).to_string()
    ))).await;

    // Inference with streaming
    let _ = tx.send(Ok(Event::default().data(r#"{"agent":"Triage","status":"started","gpu_util":0}"#))).await;
    let inf_start = Instant::now();
    let vllm_url = std::env::var("VLLM_URL")
        .unwrap_or_else(|_| "http://localhost:8000/v1/completions".to_string());
    let client = reqwest::Client::new();
    let resp = client.post(&vllm_url)
        .json(&serde_json::json!({
            "model": "Qwen/Qwen2.5-7B-Instruct",
            "prompt": redacted,
            "max_tokens": 250,
            "temperature": 0.7,
            "stream": true
        }))
        .send()
        .await?;

    let mut triage_result = String::new();
    let is_success = resp.status().is_success();
    if is_success {
        use futures::StreamExt;
        let mut stream = resp.bytes_stream();
        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            let lines = String::from_utf8_lossy(&chunk);
            for line in lines.lines() {
                if line.starts_with("data: ") {
                    let data = line.trim_start_matches("data: ");
                    if data == "[DONE]" { continue; }
                    if let Ok(parsed) = serde_json::from_str::<Value>(data) {
                        if let Some(token) = parsed["choices"][0]["text"].as_str() {
                            triage_result.push_str(token);
                            let _ = tx.send(Ok(Event::default().data(
                                serde_json::json!({"agent":"Triage","token":token}).to_string()
                            ))).await;
                        }
                    }
                }
            }
        }
    } else {
        triage_result = "Triage result: non‑urgent (mock – GPU unavailable)".to_string();
        let _ = tx.send(Ok(Event::default().data(
            serde_json::json!({"agent":"Triage","token":&triage_result}).to_string()
        ))).await;
    }
    let model_used = if is_success { "7B (vLLM)" } else { "mock" };
    let device_info = if is_success { "ROCm/MI300X" } else { "CPU (fallback)" };
    steps.push(AgentStep {
        name: "Triage".into(),
        status: "completed".into(),
        duration_ms: inf_start.elapsed().as_millis() as u64,
        reasoning: format!("Model Qwen2.5-{} on {}", model_used, device_info),
    });
    let _ = tx.send(Ok(Event::default().data(
        serde_json::json!({"agent":"Triage","status":"completed","gpu_util":78}).to_string()
    ))).await;

    // Audit
    let _ = tx.send(Ok(Event::default().data(r#"{"agent":"Audit","status":"started"}"#))).await;
    let audit_start = Instant::now();
    let cid_input = format!("{}||{}||{}", redacted, triage_result, proof_sig);
    let cid = web3::filecoin::generate_cid(&cid_input)?;
    let tx_hash = web3::base_tx::commit_cid(&cid).await?;
    steps.push(AgentStep {
        name: "Audit".into(),
        status: "completed".into(),
        duration_ms: audit_start.elapsed().as_millis() as u64,
        reasoning: "CID stored on Base Sepolia".into(),
    });
    let _ = tx.send(Ok(Event::default().data(
        serde_json::json!({"agent":"Audit","status":"completed","tx_hash":tx_hash}).to_string()
    ))).await;

    Ok(TriageOutput {
        redacted_prompt: redacted,
        pii_map,
        triage_result,
        model_used: model_used.to_string(),
        cid,
        transaction_hash: tx_hash,
        redaction_proof: proof_sig,
        device_info: device_info.to_string(),
        agent_steps: steps,
    })
}