brainworm2024 committed
Commit 87840bf · 1 Parent(s): 19544a3

Fix build: minimal deps + mock inference, keep new UI & orchestrator

Files changed (2)
  1. Cargo.toml +17 -5
  2. src/inference/qwen.rs +10 -99
Cargo.toml CHANGED
@@ -6,20 +6,32 @@ description = "Zero-trust medical AI triage gateway – AMD Hackathon 2026"

  [dependencies]
  tokio = { version = "1", features = ["full"] }
- axum = { version = "0.8", features = ["macros"] }
- tower-http = { version = "0.6", features = ["trace", "cors"] }
+ axum = { version = "0.7", features = ["macros"] }
+ tower = "0.4"
+ tower-http = { version = "0.5", features = ["trace", "cors"] }

  serde = { version = "1.0", features = ["derive"] }
  serde_json = "1.0"

  tracing = "0.1"
- tracing-subscriber = { version = "0.3", features = ["env-filter"] }
+ tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
  anyhow = "1.0"

+ # PII Shield
+ regex = "1"
+
+ # Web3 (CID + Base)
+ sha2 = "0.10"
  hex = "0.4"
+ cid = "0.11"
+ alloy = { version = "0.7", features = ["full"] }
+ alloy-provider = "0.7"
+ alloy-signer-local = "0.7"

- # Minimal dependencies for now (real Candle added locally later)
- # candle-core = "0.8" # commented until we test locally with MI300X
+ # Environment
+ dotenvy = "0.15"

  [profile.release]
+ lto = true
+ codegen-units = 1
  opt-level = 3
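
For context, a minimal sketch of how the newly added dependencies could be wired together at startup: dotenvy loads the .env file, tracing-subscriber emits JSON logs (the new "json" feature), and the tower-http trace/CORS layers sit on an axum 0.7 router. The port, route, and module layout below are illustrative assumptions, not code from this commit.

use axum::{routing::get, Router};
use tower_http::{cors::CorsLayer, trace::TraceLayer};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Load .env if present (RPC endpoints, signer keys, etc.).
    dotenvy::dotenv().ok();

    // Structured JSON logs, filtered via RUST_LOG ("env-filter" + "json" features).
    tracing_subscriber::fmt()
        .json()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .init();

    // Hypothetical router: a health check plus request tracing and permissive CORS.
    let app = Router::new()
        .route("/health", get(|| async { "ok" }))
        .layer(TraceLayer::new_for_http())
        .layer(CorsLayer::permissive());

    // 7860 is the conventional HF Space port; an assumption here.
    let listener = tokio::net::TcpListener::bind("0.0.0.0:7860").await?;
    axum::serve(listener, app).await?;
    Ok(())
}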
src/inference/qwen.rs CHANGED
@@ -1,102 +1,13 @@
- use anyhow::{Context, Result};
- use candle_core::{DType, Device, Tensor};
- use candle_nn::VarBuilder;
- use candle_transformers::generation::{LogitsProcessor, Sampling};
- use candle_transformers::models::qwen2::{Config, Model};
- use hf_hub::api::sync::Api;
- use tokenizers::Tokenizer;
- use std::sync::Arc;
- use tokio::sync::OnceCell;
-
- static MODEL: OnceCell<Arc<LoadedModel>> = OnceCell::const_new();
-
- struct LoadedModel {
-     model: Model,
-     tokenizer: Tokenizer,
-     device: Device,
-     model_name: String,
- }
-
- async fn load_model() -> Result<Arc<LoadedModel>> {
-     MODEL
-         .get_or_try_init(|| async {
-             let use_7b = std::env::var("FORCE_0_5B").unwrap_or_default() != "1";
-             let (model_id, model_name) = if use_7b {
-                 ("Qwen/Qwen2.5-7B-Instruct", "7B")
-             } else {
-                 ("Qwen/Qwen2.5-0.5B-Instruct", "0.5B")
-             };
-
-             let device = if std::env::var("ENABLE_ROCM").unwrap_or_default() == "1" {
-                 Device::new_hip(0).unwrap_or_else(|e| {
-                     tracing::warn!("HIP device not available: {}; falling back to CPU", e);
-                     Device::Cpu
-                 })
-             } else {
-                 Device::Cpu
-             };
-
-             tracing::info!("Loading model {} on {:?}", model_id, device);
-
-             let api = Api::new()?;
-             let repo = api.model(model_id.to_string());
-             let model_path = repo.get("model.safetensors")?;
-             let config_path = repo.get("config.json")?;
-             let tokenizer_path = repo.get("tokenizer.json")?;
-
-             let config: Config = serde_json::from_reader(std::fs::File::open(config_path)?)?;
-             let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)? };
-             let model = Model::new(&config, vb)?;
-             let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| anyhow::anyhow!(e))?;
-
-             Ok(Arc::new(LoadedModel { model, tokenizer, device, model_name: model_name.to_string() }))
-         })
-         .await
-         .map(Arc::clone)
- }
+ use anyhow::Result;

+ /// Mock inference for local testing / HF Space CPU.
  /// Returns (generated_text, model_used, device_info)
- pub async fn generate(redacted_prompt: &str) -> Result<(String, String, String)> {
-     match load_model().await {
-         Ok(loaded) => {
-             let prompt = format!("<|im_start|>user\n{}\n<|im_end|>\n<|im_start|>assistant\n", redacted_prompt);
-             let tokens = loaded.tokenizer.encode(prompt, true).map_err(|e| anyhow::anyhow!(e))?;
-             let input_ids = Tensor::new(tokens.get_ids(), &loaded.device)?.unsqueeze(0)?;
-             let mut output_ids = input_ids.clone();
-             let mut logits_processor = LogitsProcessor::from_sampling(Sampling::TopKTopP {
-                 k: 50,
-                 p: 0.9,
-                 temperature: 0.7,
-             });
-             let eos_token_id = loaded.tokenizer.token_to_id("<|im_end|>").unwrap_or(151643);
-             let max_new_tokens = 250;
-             let mut generated_text = String::new();
-
-             // Candle currently recomputes full attention for each token.
-             // A KV cache would speed this up and is the first post‑hackathon optimisation.
-             // For real‑time streaming (SSE), the loop can yield tokens as they are sampled.
-             for _ in 0..max_new_tokens {
-                 let logits = loaded.model.forward(&output_ids)?.squeeze(1)?;
-                 let next_token = logits_processor.sample(&logits)?;
-                 if next_token == eos_token_id {
-                     break;
-                 }
-                 output_ids = Tensor::cat(&[output_ids, next_token.unsqueeze(0)?.unsqueeze(0)?], 1)?;
-                 if let Ok(text) = loaded.tokenizer.decode(&[next_token as u32], false) {
-                     generated_text.push_str(&text);
-                 }
-             }
-
-             let device_info = format!("{:?}", loaded.device);
-             if generated_text.is_empty() {
-                 Ok(("Unable to generate output.".to_string(), loaded.model_name.clone(), device_info))
-             } else {
-                 Ok((generated_text.trim().to_string(), loaded.model_name.clone(), device_info))
-             }
-         }
-         Err(e) => {
-             tracing::warn!("Model load failed: {}; falling back to mock", e);
-             Ok(("Triage result: non‑urgent (mock – model unavailable)".to_string(), "mock".to_string(), "CPU (fallback)".to_string()))
-         }
-     }
- }
+ pub async fn generate(_redacted_prompt: &str) -> Result<(String, String, String)> {
+     tracing::info!("[MOCK] Inference skipped – returning placeholder");
+     tokio::time::sleep(std::time::Duration::from_millis(10)).await;
+     Ok((
+         "Triage result: non‑urgent (mock)".to_string(),
+         "mock".to_string(),
+         "CPU (mock)".to_string(),
+     ))
  }
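
For context, a hedged sketch of how the mock generate() above might be consumed by the orchestrator's HTTP layer. The module path (crate::inference::qwen), route name, and request/response types are assumptions for illustration; they are not part of this commit. Because the mock keeps the same (text, model, device) return signature as the removed Candle path, restoring real inference later should not require touching a handler like this.

use axum::{routing::post, Json, Router};
use serde::{Deserialize, Serialize};

#[derive(Deserialize)]
struct TriageRequest {
    // Expected to be PII-redacted upstream before reaching inference.
    prompt: String,
}

#[derive(Serialize)]
struct TriageResponse {
    output: String,
    model: String,
    device: String,
}

// Hypothetical handler calling the mock (or, later, real) inference path.
async fn triage(Json(req): Json<TriageRequest>) -> Json<TriageResponse> {
    let (output, model, device) = crate::inference::qwen::generate(&req.prompt)
        .await
        .unwrap_or_else(|e| (format!("inference error: {e}"), "none".into(), "unknown".into()));
    Json(TriageResponse { output, model, device })
}

// Hypothetical router fragment merged into the main app.
fn triage_router() -> Router {
    Router::new().route("/triage", post(triage))
}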