Commit 87840bf
Parent(s): 19544a3

Fix build: minimal deps + mock inference, keep new UI & orchestrator

- Cargo.toml +17 -5
- src/inference/qwen.rs +10 -99
Cargo.toml
CHANGED

@@ -6,20 +6,32 @@ description = "Zero-trust medical AI triage gateway – AMD Hackathon 2026"
 
 [dependencies]
 tokio = { version = "1", features = ["full"] }
-axum = { version = "0.
-tower
+axum = { version = "0.7", features = ["macros"] }
+tower = "0.4"
+tower-http = { version = "0.5", features = ["trace", "cors"] }
 
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
 
 tracing = "0.1"
-tracing-subscriber = { version = "0.3", features = ["env-filter"] }
+tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
 anyhow = "1.0"
 
+# PII Shield
+regex = "1"
+
+# Web3 (CID + Base)
+sha2 = "0.10"
 hex = "0.4"
+cid = "0.11"
+alloy = { version = "0.7", features = ["full"] }
+alloy-provider = "0.7"
+alloy-signer-local = "0.7"
 
-#
-
+# Environment
+dotenvy = "0.15"
 
 [profile.release]
+lto = true
+codegen-units = 1
 opt-level = 3
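Note: the new "Web3 (CID + Base)" dependencies are only declared here; no code in this commit uses them yet. A minimal sketch of what sha2 + cid enable, assuming a CIDv1 with the raw multicodec over a serialized triage record (record_cid and both constants are hypothetical names, not from this repo):

// Hypothetical sketch – illustrates the new sha2/cid dependencies only.
use cid::multihash::Multihash;
use cid::Cid;
use sha2::{Digest, Sha256};

const RAW_CODEC: u64 = 0x55; // multicodec code for raw bytes (assumed codec choice)
const SHA2_256: u64 = 0x12;  // multihash code for sha2-256

fn record_cid(record_json: &[u8]) -> anyhow::Result<Cid> {
    let digest = Sha256::digest(record_json);           // 32-byte sha2-256 digest
    let mh = Multihash::<64>::wrap(SHA2_256, &digest)?; // wrap the digest as a multihash
    Ok(Cid::new_v1(RAW_CODEC, mh))                      // content ID for the record
}

Anchoring that CID on Base through the new alloy crates, with dotenvy supplying the RPC URL and signer key from .env, would presumably live in a sibling module.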
src/inference/qwen.rs
CHANGED

@@ -1,102 +1,13 @@
-use anyhow::
-use candle_core::{DType, Device, Tensor};
-use candle_nn::VarBuilder;
-use candle_transformers::generation::{LogitsProcessor, Sampling};
-use candle_transformers::models::qwen2::{Config, Model};
-use hf_hub::api::sync::Api;
-use tokenizers::Tokenizer;
-use std::sync::Arc;
-use tokio::sync::OnceCell;
-
-static MODEL: OnceCell<Arc<LoadedModel>> = OnceCell::const_new();
-
-struct LoadedModel {
-    model: Model,
-    tokenizer: Tokenizer,
-    device: Device,
-    model_name: String,
-}
-
-async fn load_model() -> Result<Arc<LoadedModel>> {
-    MODEL
-        .get_or_try_init(|| async {
-            let use_7b = std::env::var("FORCE_0_5B").unwrap_or_default() != "1";
-            let (model_id, model_name) = if use_7b {
-                ("Qwen/Qwen2.5-7B-Instruct", "7B")
-            } else {
-                ("Qwen/Qwen2.5-0.5B-Instruct", "0.5B")
-            };
-
-            let device = if std::env::var("ENABLE_ROCM").unwrap_or_default() == "1" {
-                Device::new_hip(0).unwrap_or_else(|e| {
-                    tracing::warn!("HIP device not available: {}; falling back to CPU", e);
-                    Device::Cpu
-                })
-            } else {
-                Device::Cpu
-            };
-
-            tracing::info!("Loading model {} on {:?}", model_id, device);
-
-            let api = Api::new()?;
-            let repo = api.model(model_id.to_string());
-            let model_path = repo.get("model.safetensors")?;
-            let config_path = repo.get("config.json")?;
-            let tokenizer_path = repo.get("tokenizer.json")?;
-
-            let config: Config = serde_json::from_reader(std::fs::File::open(config_path)?)?;
-            let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)? };
-            let model = Model::new(&config, vb)?;
-            let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| anyhow::anyhow!(e))?;
-
-            Ok(Arc::new(LoadedModel { model, tokenizer, device, model_name: model_name.to_string() }))
-        })
-        .await
-        .map(Arc::clone)
-}
+use anyhow::Result;
 
+/// Mock inference for local testing / HF Space CPU.
 /// Returns (generated_text, model_used, device_info)
-pub async fn generate(
-…
-                k: 50,
-                p: 0.9,
-                temperature: 0.7,
-            });
-            let eos_token_id = loaded.tokenizer.token_to_id("<|im_end|>").unwrap_or(151643);
-            let max_new_tokens = 250;
-            let mut generated_text = String::new();
-
-            // Candle currently recomputes full attention for each token.
-            // A KV cache would speed this up and is the first post‑hackathon optimisation.
-            // For real‑time streaming (SSE), the loop can yield tokens as they are sampled.
-            for _ in 0..max_new_tokens {
-                let logits = loaded.model.forward(&output_ids)?.squeeze(1)?;
-                let next_token = logits_processor.sample(&logits)?;
-                if next_token == eos_token_id {
-                    break;
-                }
-                output_ids = Tensor::cat(&[output_ids, next_token.unsqueeze(0)?.unsqueeze(0)?], 1)?;
-                if let Ok(text) = loaded.tokenizer.decode(&[next_token as u32], false) {
-                    generated_text.push_str(&text);
-                }
-            }
-
-            let device_info = format!("{:?}", loaded.device);
-            if generated_text.is_empty() {
-                Ok(("Unable to generate output.".to_string(), loaded.model_name.clone(), device_info))
-            } else {
-                Ok((generated_text.trim().to_string(), loaded.model_name.clone(), device_info))
-            }
-        }
-        Err(e) => {
-            tracing::warn!("Model load failed: {}; falling back to mock", e);
-            Ok(("Triage result: non‑urgent (mock – model unavailable)".to_string(), "mock".to_string(), "CPU (fallback)".to_string()))
-        }
-    }
+pub async fn generate(_redacted_prompt: &str) -> Result<(String, String, String)> {
+    tracing::info!("[MOCK] Inference skipped – returning placeholder");
+    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
+    Ok((
+        "Triage result: non‑urgent (mock)".to_string(),
+        "mock".to_string(),
+        "CPU (mock)".to_string(),
+    ))
 }
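Because the mock keeps the (generated_text, model_used, device_info) return shape of the candle path it replaces, existing callers compile unchanged. A hypothetical call site (the orchestrator itself is not part of this diff; run_triage and the variable names are illustrative):

// Hypothetical caller of the mock generate() above.
async fn run_triage(redacted_prompt: &str) -> anyhow::Result<String> {
    let (answer, model_used, device_info) = generate(redacted_prompt).await?;
    tracing::info!("inference via {} on {}", model_used, device_info);
    Ok(answer)
}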