// rustvital-amd — src/orchestrator.rs
// Commit 6c9bf1e by brainworm2024: "Fix ownership issue in orchestrator streaming"
use std::convert::Infallible;
use crate::shield;
use crate::web3;
use crate::proof;
use anyhow::Result;
use serde::Serialize;
use std::time::Instant;
use axum::response::sse::Event;
use tokio::sync::mpsc::Sender;
use serde_json::Value;
/// Timing and outcome record for one stage of the triage pipeline.
///
/// Collected into [`TriageOutput::agent_steps`] and serialized with the rest of the output.
#[derive(Clone, Debug, Serialize)]
pub struct AgentStep {
    /// Stage name; this module records `"Shield"`, `"Triage"` and `"Audit"`.
    pub name: String,
    /// Stage status; this module only ever records `"completed"`.
    pub status: String,
    /// Wall-clock time spent in the stage, in milliseconds.
    pub duration_ms: u64,
    /// Short human-readable summary of what the stage did.
    pub reasoning: String,
}
/// Final result of one run of the Shield → Triage → Audit pipeline.
#[derive(Serialize, Debug)]
pub struct TriageOutput {
    /// Patient note after PII redaction; this is the text sent to the model.
    pub redacted_prompt: String,
    /// PII matches reported by `shield::redact::redact_pii`.
    pub pii_map: Vec<shield::redact::PiiMatch>,
    /// Model output (concatenated streamed tokens) or the fixed mock result on fallback.
    pub triage_result: String,
    /// Which backend produced `triage_result` (`"7B (vLLM)"` or `"mock"`).
    pub model_used: String,
    /// CID derived from the redacted prompt, triage result and redaction proof.
    pub cid: String,
    /// Transaction hash returned by `web3::base_tx::commit_cid` for `cid`.
    pub transaction_hash: String,
    /// Proof string produced by `proof::generate_proof` over the original note and PII map.
    pub redaction_proof: String,
    /// Device label for the triage backend (`"ROCm/MI300X"` or `"CPU (fallback)"`).
    pub device_info: String,
    /// Per-stage timing records, in pipeline order.
    pub agent_steps: Vec<AgentStep>,
}
/// Runs the full triage pipeline without streaming progress events to anyone.
///
/// Thin wrapper over [`run_triage_stream`] for callers that only need the final
/// [`TriageOutput`].
///
/// # Errors
///
/// Propagates any error returned by [`run_triage_stream`].
pub async fn run_triage(patient_note: &str) -> Result<TriageOutput> {
    // The `_` pattern drops the receiver immediately, so every progress `send` inside
    // `run_triage_stream` fails at once and is ignored there. Do NOT bind it to a name
    // (even `_rx`): a live but undrained receiver would block the pipeline as soon as
    // the 10-slot buffer filled up.
    let (progress_tx, _) = tokio::sync::mpsc::channel(10);
    let owned_note = patient_note.to_owned();
    run_triage_stream(owned_note, progress_tx).await
}
/// Runs the Shield → Triage → Audit pipeline on `patient_note`, streaming progress to `tx`.
///
/// Each stage emits JSON payloads as SSE `data:` events: a `"started"` event, then (for
/// Triage) one event per generated token, then a `"completed"` event. Send failures are
/// ignored, so a dropped receiver never aborts the pipeline.
///
/// The Triage stage streams a completion from the endpoint in the `VLLM_URL` environment
/// variable (default `http://localhost:8000/v1/completions`). If that endpoint cannot be
/// reached or answers with a non-success status, a fixed mock result is used instead.
///
/// # Errors
///
/// Returns an error if the vLLM response body fails mid-stream, or if
/// `web3::filecoin::generate_cid` or `web3::base_tx::commit_cid` fails.
pub async fn run_triage_stream(
    patient_note: String,
    tx: Sender<Result<Event, Infallible>>,
) -> Result<TriageOutput> {
    let mut steps = Vec::with_capacity(3);

    // Shield: redact PII so only the redacted text is sent to the model below.
    send_event(
        &tx,
        &serde_json::json!({"agent":"Shield","status":"started"}).to_string(),
    )
    .await;
    let shield_start = Instant::now();
    let (redacted, pii_map) = shield::redact::redact_pii(&patient_note);
    let proof_sig = proof::generate_proof(&patient_note, &pii_map);
    steps.push(AgentStep {
        name: "Shield".into(),
        status: "completed".into(),
        duration_ms: elapsed_ms(shield_start),
        reasoning: format!("Detected {} PII entities", pii_map.len()),
    });
    send_event(
        &tx,
        &serde_json::json!({"agent":"Shield","status":"completed","pii_map":pii_map}).to_string(),
    )
    .await;

    // Triage: stream a completion from vLLM, or fall back to a mock result.
    send_event(&tx, r#"{"agent":"Triage","status":"started","gpu_util":0}"#).await;
    let inf_start = Instant::now();
    let vllm_url = std::env::var("VLLM_URL")
        .unwrap_or_else(|_| "http://localhost:8000/v1/completions".to_string());
    let request = reqwest::Client::new()
        .post(&vllm_url)
        .json(&serde_json::json!({
            "model": "Qwen/Qwen2.5-7B-Instruct",
            "prompt": redacted,
            "max_tokens": 250,
            "temperature": 0.7,
            "stream": true
        }))
        .send()
        .await;
    // A transport error (e.g. nothing listening at VLLM_URL) means the GPU backend is
    // unavailable just as much as a non-2xx status does, so both take the mock path
    // instead of failing the whole pipeline.
    let live_response = match request {
        Ok(resp) if resp.status().is_success() => Some(resp),
        _ => None,
    };
    let vllm_ok = live_response.is_some();
    let triage_result = match live_response {
        Some(resp) => stream_completion(resp, &tx).await?,
        None => {
            let mock = "Triage result: non‑urgent (mock – GPU unavailable)".to_string();
            send_event(
                &tx,
                &serde_json::json!({"agent":"Triage","token":&mock}).to_string(),
            )
            .await;
            mock
        }
    };
    let model_used = if vllm_ok { "7B (vLLM)" } else { "mock" };
    let device_info = if vllm_ok { "ROCm/MI300X" } else { "CPU (fallback)" };
    // NOTE: 78 is a fixed placeholder, not a measured utilisation. Report 0 when the
    // GPU path was not used instead of claiming load on the CPU fallback.
    let gpu_util = if vllm_ok { 78 } else { 0 };
    steps.push(AgentStep {
        name: "Triage".into(),
        status: "completed".into(),
        duration_ms: elapsed_ms(inf_start),
        reasoning: format!("Model Qwen2.5-{} on {}", model_used, device_info),
    });
    send_event(
        &tx,
        &serde_json::json!({"agent":"Triage","status":"completed","gpu_util":gpu_util})
            .to_string(),
    )
    .await;

    // Audit: derive a CID from the redacted prompt, result and proof, then commit it.
    send_event(&tx, r#"{"agent":"Audit","status":"started"}"#).await;
    let audit_start = Instant::now();
    let cid_input = format!("{}||{}||{}", redacted, triage_result, proof_sig);
    let cid = web3::filecoin::generate_cid(&cid_input)?;
    let tx_hash = web3::base_tx::commit_cid(&cid).await?;
    steps.push(AgentStep {
        name: "Audit".into(),
        status: "completed".into(),
        duration_ms: elapsed_ms(audit_start),
        reasoning: "CID stored on Base Sepolia".into(),
    });
    send_event(
        &tx,
        &serde_json::json!({"agent":"Audit","status":"completed","tx_hash":tx_hash}).to_string(),
    )
    .await;

    Ok(TriageOutput {
        redacted_prompt: redacted,
        pii_map,
        triage_result,
        model_used: model_used.to_string(),
        cid,
        transaction_hash: tx_hash,
        redaction_proof: proof_sig,
        device_info: device_info.to_string(),
        agent_steps: steps,
    })
}

/// Sends one SSE `data:` event carrying `data`.
///
/// The send result is deliberately ignored. If the receiver has been dropped (a
/// disconnected client, or the discarded receiver in [`run_triage`]), the pipeline keeps
/// running.
async fn send_event(tx: &Sender<Result<Event, Infallible>>, data: &str) {
    let _ = tx.send(Ok(Event::default().data(data))).await;
}

/// Returns the milliseconds elapsed since `start`, saturating at `u64::MAX` instead of truncating.
fn elapsed_ms(start: Instant) -> u64 {
    u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX)
}

/// Consumes a vLLM `stream: true` completion response and returns the concatenated text.
///
/// Each text token is forwarded to `tx` as a `{"agent":"Triage","token":…}` event.
/// Network chunks do not line up with SSE lines or with UTF-8 character boundaries.
/// Bytes are therefore buffered until a complete `\n`-terminated line is available, and
/// only then decoded and parsed.
///
/// # Errors
///
/// Returns an error if reading a chunk of the response body fails.
async fn stream_completion(
    resp: reqwest::Response,
    tx: &Sender<Result<Event, Infallible>>,
) -> Result<String> {
    use futures::StreamExt;

    let mut stream = resp.bytes_stream();
    let mut text = String::new();
    let mut pending: Vec<u8> = Vec::new();
    let mut done = false;
    while !done {
        match stream.next().await {
            Some(chunk) => pending.extend_from_slice(&chunk?),
            None => {
                done = true;
                // The body may end without a final newline; terminate it so the
                // leftover bytes are still parsed as a line below.
                if !pending.is_empty() {
                    pending.push(b'\n');
                }
            }
        }
        // `\n` never occurs inside a multi-byte UTF-8 sequence, so each drained line
        // holds only whole characters.
        while let Some(end) = pending.iter().position(|&b| b == b'\n') {
            let line: Vec<u8> = pending.drain(..=end).collect();
            if let Some(token) = completion_token(&line) {
                text.push_str(&token);
                send_event(
                    tx,
                    &serde_json::json!({"agent":"Triage","token":token}).to_string(),
                )
                .await;
            }
        }
    }
    Ok(text)
}

/// Extracts `choices[0].text` from one raw line of a vLLM streaming completion.
///
/// Returns `None` in these cases:
/// - the line is not a `data: ` line (e.g. a blank event separator);
/// - it carries the `[DONE]` sentinel;
/// - its JSON does not parse;
/// - it has no text field.
fn completion_token(raw_line: &[u8]) -> Option<String> {
    let line = String::from_utf8_lossy(raw_line);
    let data = line
        .trim_end_matches(&['\r', '\n'][..])
        .strip_prefix("data: ")?;
    if data == "[DONE]" {
        return None;
    }
    let event: Value = serde_json::from_str(data).ok()?;
    event["choices"][0]["text"].as_str().map(str::to_owned)
}