brainworm2024 commited on
Commit
99f62cc
·
0 Parent(s):

Deploy RustVital-AMD to HF Space

Browse files
.gitignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Rust build artifacts
2
+ target/
3
+ debug/
4
+ release/
5
+
6
+ # Environment (contains private key)
7
+ .env
8
+
9
+ # IDE files
10
+ .vscode/
11
+ .idea/
12
+ *.swp
13
+ *~
14
+
15
+ # OS files
16
+ .DS_Store
17
+ Thumbs.db
Cargo.lock ADDED
The diff for this file is too large to render. See raw diff
 
Cargo.toml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [package]
2
+ name = "rustvital-amd"
3
+ version = "0.1.0"
4
+ edition = "2021"
5
+ description = "Zero-trust medical AI triage gateway – AMD Hackathon 2026"
6
+
7
+ [dependencies]
8
+ tokio = { version = "1", features = ["full"] }
9
+ axum = { version = "0.7", features = ["macros"] }
10
+ tower = "0.4"
11
+ tower-http = { version = "0.5", features = ["trace", "cors"] }
12
+
13
+ serde = { version = "1.0", features = ["derive"] }
14
+ serde_json = "1.0"
15
+
16
+ tracing = "0.1"
17
+ tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
18
+ anyhow = "1.0"
19
+
20
+ # Candle — CPU only for now
21
+ candle-core = "0.8"
22
+ candle-nn = "0.8"
23
+ candle-transformers = "0.8"
24
+ hf-hub = "0.3"
25
+
26
+ # PII Shield
27
+ pii = "0.1"
28
+ regex = "1"
29
+
30
+ # Web3
31
+ cid = "0.11" # re-exports multihash, we'll use that
32
+ sha2 = "0.10" # manual hashing for CID
33
+ alloy = { version = "0.7", features = ["full"] }
34
+ alloy-provider = "0.7"
35
+ alloy-signer-local = "0.7"
36
+
37
+ # Environment
38
+ dotenvy = "0.15"
39
+
40
+ [profile.release]
41
+ lto = true
42
+ codegen-units = 1
43
+ opt-level = 3
Dockerfile ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Build stage
2
+ FROM rust:1.78-slim-bookworm as builder
3
+ WORKDIR /app
4
+ COPY Cargo.toml Cargo.lock* ./
5
+ COPY src src
6
+ RUN cargo build --release
7
+
8
+ # Runtime stage
9
+ FROM debian:bookworm-slim
10
+ RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/*
11
+ COPY --from=builder /app/target/release/rustvital-amd /usr/local/bin/rustvital-amd
12
+ EXPOSE 7860
13
+ ENV RUST_LOG=info
14
+ ENV PRIVATE_KEY= # Set via HF Space secrets, not here
15
+ CMD ["rustvital-amd"]
README.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: RustVital-AMD
3
+ emoji: 🏥
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: false
8
+ ---
9
+
10
+ # RustVital-AMD: Zero‑Trust Medical AI Triage Gateway
11
+
12
+ Pure Rust + AMD ROCm + Web3.
13
+
14
+ - API endpoint: `/triage` (POST)
15
+ - PII redaction before GPU inference
16
+ - Filecoin CID + Base L2 audit trail
17
+
18
+ **Currently in MVP stage – mock inference, real blockchain.**
19
+ Powered by Candle, Axum, Alloy.
src/handlers/mod.rs ADDED
@@ -0,0 +1 @@
 
 
1
+ pub mod triage;
src/handlers/triage.rs ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use axum::{response::Json, http::StatusCode};
2
+ use serde::{Deserialize, Serialize};
3
+ use tracing::instrument;
4
+
5
+ use crate::shield;
6
+ use crate::inference;
7
+ use crate::web3;
8
+
9
+ #[derive(Debug, Deserialize)]
10
+ pub struct TriageRequest {
11
+ pub patient_note: String,
12
+ pub consent_hash: String,
13
+ }
14
+
15
+ #[derive(Debug, Serialize)]
16
+ pub struct TriageResponse {
17
+ pub triage_result: String,
18
+ pub transaction_hash: String,
19
+ /// Redacted text sent to the model (for audit/demo)
20
+ pub redacted_prompt: String,
21
+ /// PII map (only for verification; never in production)
22
+ pub pii_map: Vec<shield::redact::PiiMatch>,
23
+ }
24
+
25
+ #[instrument(skip_all)]
26
+ pub async fn handle(
27
+ Json(payload): Json<TriageRequest>,
28
+ ) -> Result<Json<TriageResponse>, (StatusCode, String)> {
29
+ tracing::info!("Received triage request (consent_hash: {})", payload.consent_hash);
30
+
31
+ // 1. Zero‑Trust Shield: strip PII
32
+ let (redacted_note, pii_matches) = shield::redact::redact_pii(&payload.patient_note);
33
+
34
+ // 2. Inference on redacted text (GPU never sees PII)
35
+ let triage_result = inference::qwen::generate(&redacted_note)
36
+ .await
37
+ .map_err(|e| {
38
+ tracing::error!("Inference failed: {:?}", e);
39
+ (StatusCode::INTERNAL_SERVER_ERROR, "Inference engine error".into())
40
+ })?;
41
+
42
+ // 3. Filecoin CID (immutable record of redacted prompt + result)
43
+ let cid_input = format!("{}||{}", redacted_note, triage_result);
44
+ let cid = web3::filecoin::generate_cid(&cid_input)
45
+ .map_err(|e| {
46
+ tracing::error!("CID generation failed: {:?}", e);
47
+ (StatusCode::INTERNAL_SERVER_ERROR, "CID error".into())
48
+ })?;
49
+
50
+ // 4. Base L2 transaction (posts the CID)
51
+ let tx_hash = web3::base_tx::commit_cid(&cid)
52
+ .await
53
+ .map_err(|e| {
54
+ tracing::error!("Base L2 transaction failed: {:?}", e);
55
+ (StatusCode::INTERNAL_SERVER_ERROR, "Blockchain error".into())
56
+ })?;
57
+
58
+ Ok(Json(TriageResponse {
59
+ triage_result,
60
+ transaction_hash: tx_hash,
61
+ redacted_prompt: redacted_note,
62
+ pii_map: pii_matches,
63
+ }))
64
+ }
src/inference/mod.rs ADDED
@@ -0,0 +1 @@
 
 
1
+ pub mod qwen;
src/inference/qwen.rs ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ use anyhow::Result;
2
+
3
+ /// Mock inference for local testing.
4
+ /// In production (AMD Cloud), this will load the Qwen-72B model via Candle + ROCm.
5
+ pub async fn generate(_redacted_prompt: &str) -> Result<String> {
6
+ tracing::info!("[MOCK] GPU inference skipped — returning placeholder");
7
+ // Simulate some processing
8
+ tokio::time::sleep(std::time::Duration::from_millis(10)).await;
9
+ Ok("Triage result: non‑urgent (mock)".to_string())
10
+ }
src/lib.rs ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ pub mod handlers;
2
+ pub mod inference;
3
+ pub mod shield;
4
+ pub mod web3;
src/main.rs ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use axum::{routing::post, Router};
2
+ use tower_http::trace::TraceLayer;
3
+ use tracing_subscriber::EnvFilter;
4
+
5
+ // Use the library crate to get the handler
6
+ use rustvital_amd::handlers;
7
+
8
+ #[tokio::main]
9
+ async fn main() -> anyhow::Result<()> {
10
+ // Load .env file (ignore if missing)
11
+ dotenvy::dotenv().ok();
12
+
13
+ // Initialise tracing (respects RUST_LOG from env)
14
+ tracing_subscriber::fmt()
15
+ .with_env_filter(EnvFilter::from_default_env().add_directive("rustvital_amd=debug".parse()?))
16
+ .init();
17
+
18
+ tracing::info!("Starting RustVital-AMD server");
19
+
20
+ let app = Router::new()
21
+ .route("/triage", post(handlers::triage::handle))
22
+ .layer(TraceLayer::new_for_http());
23
+
24
+ let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
25
+ axum::serve(listener, app).await?;
26
+ Ok(())
27
+ }
src/shield/mod.rs ADDED
@@ -0,0 +1 @@
 
 
1
+ pub mod redact;
src/shield/redact.rs ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use tracing::instrument;
2
+ use serde::{Serialize, Deserialize};
3
+
4
+ #[derive(Debug, Clone, Serialize, Deserialize)]
5
+ pub struct PiiMatch {
6
+ pub entity_type: String,
7
+ pub original: String,
8
+ pub placeholder: String,
9
+ }
10
+
11
+ #[instrument(skip_all)]
12
+ pub fn redact_pii(raw_text: &str) -> (String, Vec<PiiMatch>) {
13
+ let custom_patterns: Vec<(&str, &str)> = vec![
14
+ (r"\b(?:Dr\.|Dr|Professor|Prof\.)\s+[A-Z][a-z]+\b", "PROVIDER_NAME"),
15
+ // Match 2+ capitalized words (e.g., "John Smith", "Patient John Smith")
16
+ (r"\b[A-Z][a-z]+(?: [A-Z][a-z]+)+\b", "PERSON_NAME"),
17
+ (r"\b\d{1,2}/\d{1,2}/\d{2,4}\b", "DATE"),
18
+ (r"\b\d{3}-\d{2}-\d{4}\b", "SSN"),
19
+ (r"\b\d{10}\b", "PHONE"),
20
+ (r"\b[A-Z]{2}\d{2}\s?\d{2}\s?\d{2}\s?\d\b", "NHS_NUMBER"),
21
+ (r"\b\d{3}-\d{4}\b", "ZIP"),
22
+ (r"\b\d{2,3}\s?(?:years?|yo)\b", "AGE"),
23
+ (r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b", "EMAIL"),
24
+ ];
25
+
26
+ let compiled: Vec<(regex::Regex, &str)> = custom_patterns
27
+ .iter()
28
+ .filter_map(|(p, label)| regex::Regex::new(p).ok().map(|re| (re, *label)))
29
+ .collect();
30
+
31
+ let mut pii_matches = Vec::new();
32
+ let mut redacted = raw_text.to_string();
33
+ let mut counter: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
34
+
35
+ let raw_bytes = raw_text.as_bytes();
36
+ let mut all_matches: Vec<(usize, usize, &regex::Regex, &str)> = Vec::new();
37
+ for (re, entity_type) in &compiled {
38
+ for mat in re.find_iter(raw_text) {
39
+ all_matches.push((mat.start(), mat.end(), re, *entity_type));
40
+ }
41
+ }
42
+
43
+ // Sort by start, then longest match first (to prefer "Patient John Smith" over "Patient John")
44
+ all_matches.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| b.1.cmp(&a.1)));
45
+
46
+ // Remove overlapping matches
47
+ let mut used = vec![false; raw_bytes.len()];
48
+ let mut filtered = Vec::new();
49
+ for (start, end, re, entity_type) in all_matches {
50
+ let range = start..end;
51
+ if used[range.clone()].iter().any(|&b| b) {
52
+ continue;
53
+ }
54
+ used[range].fill(true);
55
+ filtered.push((start, end, re, entity_type));
56
+ }
57
+
58
+ // Process in reverse order for safe replacement
59
+ filtered.sort_by(|a, b| b.0.cmp(&a.0));
60
+ for (start, end, _re, entity_type) in filtered {
61
+ let original = &raw_text[start..end];
62
+ let count = counter.entry(entity_type.to_string()).or_insert(0);
63
+ *count += 1;
64
+ let placeholder = format!("[{}_{}]", entity_type, count);
65
+ pii_matches.push(PiiMatch {
66
+ entity_type: entity_type.to_string(),
67
+ original: original.to_string(),
68
+ placeholder: placeholder.clone(),
69
+ });
70
+ redacted.replace_range(start..end, &placeholder);
71
+ }
72
+
73
+ tracing::debug!("Redacted {} PII tokens", pii_matches.len());
74
+ (redacted, pii_matches)
75
+ }
76
+
77
+ #[cfg(test)]
78
+ mod tests {
79
+ use super::*;
80
+
81
+ #[test]
82
+ fn redact_pii_basic() {
83
+ let text = "Patient John Smith, 45 yo, phone 5551234567, SSN 123-45-6789.";
84
+ let (redacted, matches) = redact_pii(text);
85
+
86
+ // All PII tokens must be gone
87
+ assert!(!redacted.contains("John"));
88
+ assert!(!redacted.contains("Smith"));
89
+ assert!(!redacted.contains("5551234567"));
90
+ assert!(!redacted.contains("123-45-6789"));
91
+
92
+ // Placeholders present
93
+ assert!(redacted.contains("[PERSON_NAME_1]"));
94
+ assert!(redacted.contains("[PHONE_1]"));
95
+ assert!(redacted.contains("[SSN_1]"));
96
+
97
+ // Verify we captured the full name
98
+ assert!(matches.iter().any(|m| m.original.contains("John Smith") || m.original.contains("Patient John Smith")));
99
+ }
100
+ }
src/web3/base_tx.rs ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use anyhow::Result;
2
+ use alloy::{
3
+ network::TransactionBuilder,
4
+ primitives::U256,
5
+ providers::{Provider, ProviderBuilder},
6
+ signers::local::PrivateKeySigner,
7
+ rpc::types::TransactionRequest,
8
+ };
9
+ use std::str::FromStr;
10
+
11
+ pub async fn commit_cid(cid: &str) -> Result<String> {
12
+ let rpc_url = "https://sepolia.base.org".parse()?;
13
+
14
+ let priv_key = std::env::var("PRIVATE_KEY")
15
+ .unwrap_or_else(|_| "0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80".to_string());
16
+
17
+ let signer = PrivateKeySigner::from_str(&priv_key)?;
18
+ let from = signer.address();
19
+ let wallet = alloy::network::EthereumWallet::from(signer);
20
+
21
+ let provider = ProviderBuilder::new()
22
+ .wallet(wallet)
23
+ .on_http(rpc_url);
24
+
25
+ let data = cid.as_bytes().to_vec();
26
+
27
+ let mut tx = TransactionRequest::default()
28
+ .with_from(from)
29
+ .with_to(from) // self‑transfer
30
+ .with_value(U256::ZERO)
31
+ .with_input(data)
32
+ .with_chain_id(84532); // <-- Base Sepolia chain ID
33
+
34
+ let nonce = provider.get_transaction_count(from).await?;
35
+ let gas_limit = provider.estimate_gas(&tx).await?;
36
+ let fee = provider.estimate_eip1559_fees(None).await?;
37
+
38
+ tx = tx
39
+ .with_nonce(nonce)
40
+ .with_gas_limit(gas_limit)
41
+ .with_max_fee_per_gas(fee.max_fee_per_gas)
42
+ .with_max_priority_fee_per_gas(fee.max_priority_fee_per_gas);
43
+
44
+ // Retry loop (handles 502 errors)
45
+ let mut attempts = 0;
46
+ loop {
47
+ match provider.send_transaction(tx.clone()).await {
48
+ Ok(pending) => {
49
+ let receipt = pending.get_receipt().await?;
50
+ let tx_hash = receipt.transaction_hash;
51
+ tracing::info!(
52
+ "Committed CID {} to Base Sepolia, tx: {:?}",
53
+ cid,
54
+ tx_hash
55
+ );
56
+ return Ok(format!("{:?}", tx_hash));
57
+ }
58
+ Err(e) if attempts < 3 && e.to_string().contains("502") => {
59
+ attempts += 1;
60
+ tokio::time::sleep(std::time::Duration::from_secs(2)).await;
61
+ continue;
62
+ }
63
+ Err(e) => return Err(anyhow::Error::new(e).context("send_transaction failed")),
64
+ }
65
+ }
66
+ }
src/web3/filecoin.rs ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use anyhow::Result;
2
+ use cid::Cid;
3
+ use sha2::{Digest, Sha256};
4
+
5
+ /// Generate a Filecoin-compatible CIDv1 (raw codec, sha2-256)
6
+ pub fn generate_cid(data: &str) -> Result<String> {
7
+ // Compute SHA-256 hash
8
+ let mut hasher = Sha256::new();
9
+ hasher.update(data.as_bytes());
10
+ let hash_bytes = hasher.finalize();
11
+
12
+ // Build multihash (code 0x12 = sha2-256, size 32)
13
+ let mh = cid::multihash::Multihash::wrap(0x12, &hash_bytes)
14
+ .map_err(|e| anyhow::anyhow!("Invalid multihash: {:?}", e))?;
15
+
16
+ // CIDv1, raw multicodec (0x55)
17
+ let cid = Cid::new_v1(0x55, mh);
18
+ Ok(cid.to_string())
19
+ }
20
+
21
+ #[cfg(test)]
22
+ mod tests {
23
+ use super::*;
24
+
25
+ #[test]
26
+ fn cid_format() {
27
+ let cid = generate_cid("hello world").unwrap();
28
+ // CIDv1 in base32 always starts with 'b'
29
+ assert!(cid.starts_with('b'), "CID should start with 'b': {}", cid);
30
+ // Must be non-empty and decodable
31
+ let decoded = Cid::try_from(cid.as_str()).unwrap();
32
+ assert_eq!(decoded.version(), cid::Version::V1);
33
+ println!("CID: {}", cid); // e.g., bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi
34
+ }
35
+ }
src/web3/mod.rs ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ pub mod filecoin;
2
+ pub mod base_tx;